Skip to content

Commit 191a4de

Browse files
- Remove cnts[b,h,w] intrinsics from MLIR and fix tests
- Remove ZACount class from arm_sme.td
1 parent 65da718 commit 191a4de

File tree

6 files changed

+40
-37
lines changed

6 files changed

+40
-37
lines changed

clang/include/clang/Basic/arm_sme.td

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,15 +156,10 @@ let SMETargetGuard = "sme2p1" in {
156156
////////////////////////////////////////////////////////////////////////////////
157157
// SME - Counting elements in a streaming vector
158158

159-
multiclass ZACount<string intr, string n_suffix> {
160-
def NAME : SInst<"sv"#n_suffix, "nv", "", MergeNone,
161-
intr, [IsOverloadNone, IsStreamingCompatible]>;
162-
}
163-
164-
defm SVCNTSB : ZACount<"", "cntsb">;
165-
defm SVCNTSH : ZACount<"", "cntsh">;
166-
defm SVCNTSW : ZACount<"", "cntsw">;
167-
defm SVCNTSD : ZACount<"aarch64_sme_cntsd", "cntsd">;
159+
def SVCNTSB : SInst<"svcntsb", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>;
160+
def SVCNTSH : SInst<"svcntsh", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>;
161+
def SVCNTSW : SInst<"svcntsw", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>;
162+
def SVCNTSD : SInst<"svcntsd", "nv", "", MergeNone, "aarch64_sme_cntsd", [IsOverloadNone, IsStreamingCompatible]>;
168163

169164
////////////////////////////////////////////////////////////////////////////////
170165
// SME - ADDHA/ADDVA

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,6 @@ class ArmSME_IntrCountOp<string mnemonic>
201201
/*traits*/[PredOpTrait<"`res` is i64", TypeIsPred<"res", I64>>],
202202
/*numResults=*/1, /*overloadedResults=*/[]>;
203203

204-
def LLVM_aarch64_sme_cntsb : ArmSME_IntrCountOp<"cntsb">;
205-
def LLVM_aarch64_sme_cntsh : ArmSME_IntrCountOp<"cntsh">;
206-
def LLVM_aarch64_sme_cntsw : ArmSME_IntrCountOp<"cntsw">;
207204
def LLVM_aarch64_sme_cntsd : ArmSME_IntrCountOp<"cntsd">;
208205

209206
#endif // ARMSME_INTRINSIC_OPS

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -822,16 +822,18 @@ struct OuterProductWideningOpConversion
822822
}
823823
};
824824

825-
/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
825+
/// Lower `arm_sme.streaming_vl` to the SME CNTSD intrinsic, multiplying the
/// result by a constant (8, 4 or 2) for byte, half and word element sizes.
826826
///
827827
/// Example:
828828
///
829829
/// %0 = arm_sme.streaming_vl <half>
830830
///
831831
/// is converted to:
832832
///
833-
/// %cnt = "arm_sme.intr.cntsh"() : () -> i64
834-
/// %0 = arith.index_cast %cnt : i64 to index
833+
/// %cnt = "arm_sme.intr.cntsd"() : () -> i64
834+
/// %0 = arith.constant 4 : i64
835+
/// %1 = arith.muli %cnt, %0 : i64
836+
/// %2 = arith.index_cast %1 : i64 to index
835837
///
836838
struct StreamingVLOpConversion
837839
: public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
@@ -845,15 +847,25 @@ struct StreamingVLOpConversion
845847
auto loc = streamingVlOp.getLoc();
846848
auto i64Type = rewriter.getI64Type();
847849
auto *intrOp = [&]() -> Operation * {
850+
auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
848851
switch (streamingVlOp.getTypeSize()) {
849-
case arm_sme::TypeSize::Byte:
850-
return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type);
851-
case arm_sme::TypeSize::Half:
852-
return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type);
853-
case arm_sme::TypeSize::Word:
854-
return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type);
852+
case arm_sme::TypeSize::Byte: {
853+
auto mul = arith::ConstantIndexOp::create(rewriter, loc, 8);
854+
auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
855+
return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
856+
}
857+
case arm_sme::TypeSize::Half: {
858+
auto mul = arith::ConstantIndexOp::create(rewriter, loc, 4);
859+
auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
860+
return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
861+
}
862+
case arm_sme::TypeSize::Word: {
863+
auto mul = arith::ConstantIndexOp::create(rewriter, loc, 2);
864+
auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
865+
return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
866+
}
855867
case arm_sme::TypeSize::Double:
856-
return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
868+
return cntsd;
857869
}
858870
llvm_unreachable("unknown type size in StreamingVLOpConversion");
859871
}();
@@ -964,9 +976,7 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
964976
arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
965977
arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
966978
arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
967-
arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
968-
arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
969-
arm_sme::aarch64_sme_cntsd>();
979+
arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsd>();
970980
target.addLegalDialect<arith::ArithDialect,
971981
/* The following are used to lower tile spills/fills */
972982
vector::VectorDialect, scf::SCFDialect,

mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -586,9 +586,10 @@ func.func @arm_sme_extract_tile_slice_ver_i128(%tile_slice_index : index) -> vec
586586
// -----
587587

588588
// CHECK-LABEL: @arm_sme_streaming_vl_bytes
589-
// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
590-
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
591-
// CHECK: return %[[INDEX_COUNT]] : index
589+
// CHECK: %[[CONST:.*]] = arith.constant 8 : i64
590+
// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
591+
// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
592+
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
592593
func.func @arm_sme_streaming_vl_bytes() -> index {
593594
%svl_b = arm_sme.streaming_vl <byte>
594595
return %svl_b : index
@@ -597,7 +598,10 @@ func.func @arm_sme_streaming_vl_bytes() -> index {
597598
// -----
598599

599600
// CHECK-LABEL: @arm_sme_streaming_vl_half_words
600-
// CHECK: "arm_sme.intr.cntsh"() : () -> i64
601+
// CHECK: %[[CONST:.*]] = arith.constant 4 : i64
602+
// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
603+
// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
604+
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
601605
func.func @arm_sme_streaming_vl_half_words() -> index {
602606
%svl_h = arm_sme.streaming_vl <half>
603607
return %svl_h : index
@@ -606,7 +610,10 @@ func.func @arm_sme_streaming_vl_half_words() -> index {
606610
// -----
607611

608612
// CHECK-LABEL: @arm_sme_streaming_vl_words
609-
// CHECK: "arm_sme.intr.cntsw"() : () -> i64
613+
// CHECK: %[[CONST:.*]] = arith.constant 2 : i64
614+
// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
615+
// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
616+
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
610617
func.func @arm_sme_streaming_vl_words() -> index {
611618
%svl_w = arm_sme.streaming_vl <word>
612619
return %svl_w : index

mlir/test/Target/LLVMIR/arm-sme-invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
3636

3737
llvm.func @arm_sme_streaming_vl_invalid_return_type() -> i32 {
3838
// expected-error @+1 {{failed to verify that `res` is i64}}
39-
%res = "arm_sme.intr.cntsb"() : () -> i32
39+
%res = "arm_sme.intr.cntsd"() : () -> i32
4040
llvm.return %res : i32
4141
}

mlir/test/Target/LLVMIR/arm-sme.mlir

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -419,12 +419,6 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
419419
// -----
420420

421421
llvm.func @arm_sme_streaming_vl() {
422-
// CHECK: call i64 @llvm.aarch64.sme.cntsb()
423-
%svl_b = "arm_sme.intr.cntsb"() : () -> i64
424-
// CHECK: call i64 @llvm.aarch64.sme.cntsh()
425-
%svl_h = "arm_sme.intr.cntsh"() : () -> i64
426-
// CHECK: call i64 @llvm.aarch64.sme.cntsw()
427-
%svl_w = "arm_sme.intr.cntsw"() : () -> i64
428422
// CHECK: call i64 @llvm.aarch64.sme.cntsd()
429423
%svl_d = "arm_sme.intr.cntsd"() : () -> i64
430424
llvm.return

0 commit comments

Comments
 (0)