Skip to content

Commit 65da718

Browse files
- Replace cnts[b|h|w] builtins with cntsd intrinsic in Clang
- Remove cnts[b|h|w] intrinsics in LLVM - Add patterns for cntsd
1 parent fd8ff8a commit 65da718

File tree

12 files changed

+136
-158
lines changed

12 files changed

+136
-158
lines changed

clang/include/clang/Basic/arm_sme.td

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,15 @@ let SMETargetGuard = "sme2p1" in {
156156
////////////////////////////////////////////////////////////////////////////////
157157
// SME - Counting elements in a streaming vector
158158

159-
multiclass ZACount<string n_suffix> {
160-
def NAME : SInst<"sv" # n_suffix, "nv", "", MergeNone,
161-
"aarch64_sme_" # n_suffix,
162-
[IsOverloadNone, IsStreamingCompatible]>;
159+
multiclass ZACount<string intr, string n_suffix> {
160+
def NAME : SInst<"sv"#n_suffix, "nv", "", MergeNone,
161+
intr, [IsOverloadNone, IsStreamingCompatible]>;
163162
}
164163

165-
defm SVCNTSB : ZACount<"cntsb">;
166-
defm SVCNTSH : ZACount<"cntsh">;
167-
defm SVCNTSW : ZACount<"cntsw">;
168-
defm SVCNTSD : ZACount<"cntsd">;
164+
defm SVCNTSB : ZACount<"", "cntsb">;
165+
defm SVCNTSH : ZACount<"", "cntsh">;
166+
defm SVCNTSW : ZACount<"", "cntsw">;
167+
defm SVCNTSD : ZACount<"aarch64_sme_cntsd", "cntsd">;
169168

170169
////////////////////////////////////////////////////////////////////////////////
171170
// SME - ADDHA/ADDVA

clang/lib/CodeGen/TargetBuiltins/ARM.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4304,9 +4304,10 @@ Value *CodeGenFunction::EmitSMELd1St1(const SVETypeFlags &TypeFlags,
43044304
// size in bytes.
43054305
if (Ops.size() == 5) {
43064306
Function *StreamingVectorLength =
4307-
CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsb);
4307+
CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd);
43084308
llvm::Value *StreamingVectorLengthCall =
4309-
Builder.CreateCall(StreamingVectorLength);
4309+
Builder.CreateMul(Builder.CreateCall(StreamingVectorLength),
4310+
llvm::ConstantInt::get(Int64Ty, 8), "svl");
43104311
llvm::Value *Mulvl =
43114312
Builder.CreateMul(StreamingVectorLengthCall, Ops[4], "mulvl");
43124313
// The type of the ptr parameter is void *, so use Int8Ty here.
@@ -4918,6 +4919,31 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
49184919
// Handle builtins which require their multi-vector operands to be swapped
49194920
swapCommutativeSMEOperands(BuiltinID, Ops);
49204921

4922+
auto isCntsBuiltin = [&](int64_t &Mul) {
4923+
switch (BuiltinID) {
4924+
default:
4925+
Mul = 0;
4926+
return false;
4927+
case SME::BI__builtin_sme_svcntsb:
4928+
Mul = 8;
4929+
return true;
4930+
case SME::BI__builtin_sme_svcntsh:
4931+
Mul = 4;
4932+
return true;
4933+
case SME::BI__builtin_sme_svcntsw:
4934+
Mul = 2;
4935+
return true;
4936+
}
4937+
};
4938+
4939+
int64_t Mul = 0;
4940+
if (isCntsBuiltin(Mul)) {
4941+
llvm::Value *Cntd =
4942+
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd));
4943+
return Builder.CreateMul(Cntd, llvm::ConstantInt::get(Int64Ty, Mul),
4944+
"mulsvl", /* HasNUW */ true, /* HasNSW */ true);
4945+
}
4946+
49214947
// Should not happen!
49224948
if (Builtin->LLVMIntrinsic == 0)
49234949
return nullptr;

clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,49 +6,55 @@
66

77
#include <arm_sme.h>
88

9-
// CHECK-C-LABEL: define dso_local i64 @test_svcntsb(
9+
// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsb(
1010
// CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] {
1111
// CHECK-C-NEXT: entry:
12-
// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
13-
// CHECK-C-NEXT: ret i64 [[TMP0]]
12+
// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
13+
// CHECK-C-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 3
14+
// CHECK-C-NEXT: ret i64 [[MULSVL]]
1415
//
15-
// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntsbv(
16+
// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntsbv(
1617
// CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] {
1718
// CHECK-CXX-NEXT: entry:
18-
// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
19-
// CHECK-CXX-NEXT: ret i64 [[TMP0]]
19+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
20+
// CHECK-CXX-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 3
21+
// CHECK-CXX-NEXT: ret i64 [[MULSVL]]
2022
//
2123
uint64_t test_svcntsb() {
2224
return svcntsb();
2325
}
2426

25-
// CHECK-C-LABEL: define dso_local i64 @test_svcntsh(
27+
// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsh(
2628
// CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0]] {
2729
// CHECK-C-NEXT: entry:
28-
// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsh()
29-
// CHECK-C-NEXT: ret i64 [[TMP0]]
30+
// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
31+
// CHECK-C-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 2
32+
// CHECK-C-NEXT: ret i64 [[MULSVL]]
3033
//
31-
// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntshv(
34+
// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntshv(
3235
// CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0]] {
3336
// CHECK-CXX-NEXT: entry:
34-
// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsh()
35-
// CHECK-CXX-NEXT: ret i64 [[TMP0]]
37+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
38+
// CHECK-CXX-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 2
39+
// CHECK-CXX-NEXT: ret i64 [[MULSVL]]
3640
//
3741
uint64_t test_svcntsh() {
3842
return svcntsh();
3943
}
4044

41-
// CHECK-C-LABEL: define dso_local i64 @test_svcntsw(
45+
// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsw(
4246
// CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0]] {
4347
// CHECK-C-NEXT: entry:
44-
// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsw()
45-
// CHECK-C-NEXT: ret i64 [[TMP0]]
48+
// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
49+
// CHECK-C-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 1
50+
// CHECK-C-NEXT: ret i64 [[MULSVL]]
4651
//
47-
// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntswv(
52+
// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntswv(
4853
// CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0]] {
4954
// CHECK-CXX-NEXT: entry:
50-
// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsw()
51-
// CHECK-CXX-NEXT: ret i64 [[TMP0]]
55+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
56+
// CHECK-CXX-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 1
57+
// CHECK-CXX-NEXT: ret i64 [[MULSVL]]
5258
//
5359
uint64_t test_svcntsw() {
5460
return svcntsw();

llvm/include/llvm/IR/IntrinsicsAArch64.td

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3147,13 +3147,8 @@ let TargetPrefix = "aarch64" in {
31473147
// Counting elements
31483148
//
31493149

3150-
class AdvSIMD_SME_CNTSB_Intrinsic
3151-
: DefaultAttrsIntrinsic<[llvm_i64_ty], [], [IntrNoMem]>;
3152-
3153-
def int_aarch64_sme_cntsb : AdvSIMD_SME_CNTSB_Intrinsic;
3154-
def int_aarch64_sme_cntsh : AdvSIMD_SME_CNTSB_Intrinsic;
3155-
def int_aarch64_sme_cntsw : AdvSIMD_SME_CNTSB_Intrinsic;
3156-
def int_aarch64_sme_cntsd : AdvSIMD_SME_CNTSB_Intrinsic;
3150+
def int_aarch64_sme_cntsd
3151+
: DefaultAttrsIntrinsic<[llvm_i64_ty], [], [IntrNoMem]>;
31573152

31583153
//
31593154
// PSTATE Functions

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
7171
template <signed Low, signed High, signed Scale>
7272
bool SelectRDVLImm(SDValue N, SDValue &Imm);
7373

74+
template <signed Low, signed High>
75+
bool SelectRDSVLShiftImm(SDValue N, SDValue &Imm);
76+
7477
bool SelectArithExtendedRegister(SDValue N, SDValue &Reg, SDValue &Shift);
7578
bool SelectArithUXTXRegister(SDValue N, SDValue &Reg, SDValue &Shift);
7679
bool SelectArithImmed(SDValue N, SDValue &Val, SDValue &Shift);
@@ -937,6 +940,23 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) {
937940
return false;
938941
}
939942

943+
template <signed Low, signed High>
944+
bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) {
945+
if (!isa<ConstantSDNode>(N))
946+
return false;
947+
948+
int64_t ShlImm = cast<ConstantSDNode>(N)->getSExtValue();
949+
if (ShlImm >= 3) {
950+
int64_t MulImm = 1 << (ShlImm - 3);
951+
if (MulImm >= Low && MulImm <= High) {
952+
Imm = CurDAG->getSignedTargetConstant(MulImm, SDLoc(N), MVT::i32);
953+
return true;
954+
}
955+
}
956+
957+
return false;
958+
}
959+
940960
/// SelectArithExtendedRegister - Select a "extended register" operand. This
941961
/// operand folds in an extend followed by an optional left shift.
942962
bool AArch64DAGToDAGISel::SelectArithExtendedRegister(SDValue N, SDValue &Reg,

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6266,27 +6266,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
62666266
case Intrinsic::aarch64_sve_clz:
62676267
return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, DL, Op.getValueType(),
62686268
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
6269-
case Intrinsic::aarch64_sme_cntsb: {
6270-
SDValue Cntd = DAG.getNode(
6271-
ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
6272-
DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
6273-
return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
6274-
DAG.getConstant(8, DL, MVT::i64));
6275-
}
6276-
case Intrinsic::aarch64_sme_cntsh: {
6277-
SDValue Cntd = DAG.getNode(
6278-
ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
6279-
DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
6280-
return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
6281-
DAG.getConstant(4, DL, MVT::i64));
6282-
}
6283-
case Intrinsic::aarch64_sme_cntsw: {
6284-
SDValue Cntd = DAG.getNode(
6285-
ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
6286-
DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
6287-
return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
6288-
DAG.getConstant(2, DL, MVT::i64));
6289-
}
62906269
case Intrinsic::aarch64_sve_cnt: {
62916270
SDValue Data = Op.getOperand(3);
62926271
// CTPOP only supports integer operands.

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,12 @@ def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
127127
def SDT_AArch64RDSVL : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
128128
def AArch64rdsvl : SDNode<"AArch64ISD::RDSVL", SDT_AArch64RDSVL>;
129129

130-
def sme_cntsb_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
131-
def sme_cntsh_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
132-
def sme_cntsw_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
133-
def sme_cntsd_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
130+
def sme_cntsb_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
131+
def sme_cntsh_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
132+
def sme_cntsw_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
133+
def sme_cntsd_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
134+
135+
def sme_cnts_shl_imm : ComplexPattern<i64, 1, "SelectRDSVLShiftImm<1, 31>">;
134136

135137
let Predicates = [HasSMEandIsNonStreamingSafe] in {
136138
def RDSVLI_XI : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1>;
@@ -140,21 +142,21 @@ def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>;
140142
def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
141143

142144
// e.g. cntsb() * imm
143-
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_imm i64:$imm))),
145+
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_mul_imm i64:$imm))),
144146
(RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
145-
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_imm i64:$imm))),
147+
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_mul_imm i64:$imm))),
146148
(UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 1, 63)>;
147-
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_imm i64:$imm))),
149+
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_mul_imm i64:$imm))),
148150
(UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 2, 63)>;
149-
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_imm i64:$imm))),
151+
def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_mul_imm i64:$imm))),
150152
(UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 3, 63)>;
151153

152-
// e.g. cntsb()
153-
def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>;
154-
def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>;
155-
def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 3))), (RDSVLI_XI 1)>;
154+
def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (sme_cnts_shl_imm i64:$imm))),
155+
(RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
156156

157-
// Generic pattern for cntsd (RDSVL #1 >> 3)
157+
// cntsh, cntsw, cntsd
158+
def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>;
159+
def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>;
158160
def : Pat<(i64 (int_aarch64_sme_cntsd)), (UBFMXri (RDSVLI_XI 1), 3, 63)>;
159161
}
160162

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2102,15 +2102,15 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
21022102
}
21032103

21042104
static std::optional<Instruction *>
2105-
instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts,
2105+
instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II,
21062106
const AArch64Subtarget *ST) {
21072107
if (!ST->isStreaming())
21082108
return std::nullopt;
21092109

2110-
// In streaming-mode, aarch64_sme_cnts is equivalent to aarch64_sve_cnt
2110+
// In streaming-mode, aarch64_sme_cntds is equivalent to aarch64_sve_cntd
21112111
// with SVEPredPattern::all
2112-
Value *Cnt = IC.Builder.CreateElementCount(
2113-
II.getType(), ElementCount::getScalable(NumElts));
2112+
Value *Cnt =
2113+
IC.Builder.CreateElementCount(II.getType(), ElementCount::getScalable(2));
21142114
Cnt->takeName(&II);
21152115
return IC.replaceInstUsesWith(II, Cnt);
21162116
}
@@ -2825,13 +2825,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
28252825
case Intrinsic::aarch64_sve_cntb:
28262826
return instCombineSVECntElts(IC, II, 16);
28272827
case Intrinsic::aarch64_sme_cntsd:
2828-
return instCombineSMECntsElts(IC, II, 2, ST);
2829-
case Intrinsic::aarch64_sme_cntsw:
2830-
return instCombineSMECntsElts(IC, II, 4, ST);
2831-
case Intrinsic::aarch64_sme_cntsh:
2832-
return instCombineSMECntsElts(IC, II, 8, ST);
2833-
case Intrinsic::aarch64_sme_cntsb:
2834-
return instCombineSMECntsElts(IC, II, 16, ST);
2828+
return instCombineSMECntsElts(IC, II, ST);
28352829
case Intrinsic::aarch64_sve_ptest_any:
28362830
case Intrinsic::aarch64_sve_ptest_first:
28372831
case Intrinsic::aarch64_sve_ptest_last:

0 commit comments

Comments
 (0)