Skip to content

Commit e2d91e5

Browse files
- Remove lambda from StreamingVLOpConversion
- Add getSizeInBytes helper
1 parent 191a4de commit e2d91e5

File tree

7 files changed

+46
-48
lines changed

7 files changed

+46
-48
lines changed

clang/lib/CodeGen/TargetBuiltins/ARM.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4919,25 +4919,20 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
49194919
// Handle builtins which require their multi-vector operands to be swapped
49204920
swapCommutativeSMEOperands(BuiltinID, Ops);
49214921

4922-
auto isCntsBuiltin = [&](int64_t &Mul) {
4922+
auto isCntsBuiltin = [&]() {
49234923
switch (BuiltinID) {
49244924
default:
4925-
Mul = 0;
4926-
return false;
4925+
return 0;
49274926
case SME::BI__builtin_sme_svcntsb:
4928-
Mul = 8;
4929-
return true;
4927+
return 8;
49304928
case SME::BI__builtin_sme_svcntsh:
4931-
Mul = 4;
4932-
return true;
4929+
return 4;
49334930
case SME::BI__builtin_sme_svcntsw:
4934-
Mul = 2;
4935-
return true;
4931+
return 2;
49364932
}
49374933
};
49384934

4939-
int64_t Mul = 0;
4940-
if (isCntsBuiltin(Mul)) {
4935+
if (auto Mul = isCntsBuiltin()) {
49414936
llvm::Value *Cntd =
49424937
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd));
49434938
return Builder.CreateMul(Cntd, llvm::ConstantInt::get(Int64Ty, Mul),

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,9 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) {
940940
return false;
941941
}
942942

943+
// Given cntsd = (rdsvl, #1) >> 3, attempt to return a suitable multiplier
944+
// for RDSVL to calculate the streaming vector length in bytes * N. i.e.
945+
// rdsvl, #(ShlImm - 3)
943946
template <signed Low, signed High>
944947
bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) {
945948
if (!isa<ConstantSDNode>(N))

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2102,8 +2102,8 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
21022102
}
21032103

21042104
static std::optional<Instruction *>
2105-
instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II,
2106-
const AArch64Subtarget *ST) {
2105+
instCombineSMECntsd(InstCombiner &IC, IntrinsicInst &II,
2106+
const AArch64Subtarget *ST) {
21072107
if (!ST->isStreaming())
21082108
return std::nullopt;
21092109

@@ -2825,7 +2825,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
28252825
case Intrinsic::aarch64_sve_cntb:
28262826
return instCombineSVECntElts(IC, II, 16);
28272827
case Intrinsic::aarch64_sme_cntsd:
2828-
return instCombineSMECntsElts(IC, II, ST);
2828+
return instCombineSMECntsd(IC, II, ST);
28292829
case Intrinsic::aarch64_sve_ptest_any:
28302830
case Intrinsic::aarch64_sve_ptest_first:
28312831
case Intrinsic::aarch64_sve_ptest_last:

mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ namespace mlir::arm_sme {
3232

3333
constexpr unsigned MinStreamingVectorLengthInBits = 128;
3434

35+
/// Return the size represented by arm_sme::TypeSize in bytes.
36+
unsigned getSizeInBytes(TypeSize type);
37+
3538
/// Return minimum number of elements for the given element `type` in
3639
/// a vector of SVL bits.
3740
unsigned getSMETileSliceMinNumElts(Type type);

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -846,31 +846,13 @@ struct StreamingVLOpConversion
846846
ConversionPatternRewriter &rewriter) const override {
847847
auto loc = streamingVlOp.getLoc();
848848
auto i64Type = rewriter.getI64Type();
849-
auto *intrOp = [&]() -> Operation * {
850-
auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
851-
switch (streamingVlOp.getTypeSize()) {
852-
case arm_sme::TypeSize::Byte: {
853-
auto mul = arith::ConstantIndexOp::create(rewriter, loc, 8);
854-
auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
855-
return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
856-
}
857-
case arm_sme::TypeSize::Half: {
858-
auto mul = arith::ConstantIndexOp::create(rewriter, loc, 4);
859-
auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
860-
return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
861-
}
862-
case arm_sme::TypeSize::Word: {
863-
auto mul = arith::ConstantIndexOp::create(rewriter, loc, 2);
864-
auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
865-
return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
866-
}
867-
case arm_sme::TypeSize::Double:
868-
return cntsd;
869-
}
870-
llvm_unreachable("unknown type size in StreamingVLOpConversion");
871-
}();
872-
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
873-
streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
849+
auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
850+
auto cntsdIdx = arith::IndexCastOp::create(rewriter, loc,
851+
rewriter.getIndexType(), cntsd);
852+
auto scale = arith::ConstantIndexOp::create(
853+
rewriter, loc,
854+
8 / arm_sme::getSizeInBytes(streamingVlOp.getTypeSize()));
855+
rewriter.replaceOpWithNewOp<arith::MulIOp>(streamingVlOp, cntsdIdx, scale);
874856
return success();
875857
}
876858
};

mlir/lib/Dialect/ArmSME/IR/Utils.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,21 @@
1414

1515
namespace mlir::arm_sme {
1616

17+
unsigned getSizeInBytes(TypeSize type) {
18+
switch (type) {
19+
case arm_sme::TypeSize::Byte:
20+
return 1;
21+
case arm_sme::TypeSize::Half:
22+
return 2;
23+
case arm_sme::TypeSize::Word:
24+
return 4;
25+
case arm_sme::TypeSize::Double:
26+
return 8;
27+
default:
28+
llvm_unreachable("unknown type size");
29+
}
30+
}
31+
1732
unsigned getSMETileSliceMinNumElts(Type type) {
1833
assert(isValidSMETileElementType(type) && "invalid tile type!");
1934
return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();

mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -586,10 +586,10 @@ func.func @arm_sme_extract_tile_slice_ver_i128(%tile_slice_index : index) -> vec
586586
// -----
587587

588588
// CHECK-LABEL: @arm_sme_streaming_vl_bytes
589-
// CHECK: %[[CONST:.*]] = arith.constant 8 : i64
589+
// CHECK: %[[CONST:.*]] = arith.constant 8 : index
590590
// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
591-
// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
592-
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
591+
// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
592+
// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
593593
func.func @arm_sme_streaming_vl_bytes() -> index {
594594
%svl_b = arm_sme.streaming_vl <byte>
595595
return %svl_b : index
@@ -598,10 +598,10 @@ func.func @arm_sme_streaming_vl_bytes() -> index {
598598
// -----
599599

600600
// CHECK-LABEL: @arm_sme_streaming_vl_half_words
601-
// CHECK: %[[CONST:.*]] = arith.constant 4 : i64
601+
// CHECK: %[[CONST:.*]] = arith.constant 4 : index
602602
// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
603-
// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
604-
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
603+
// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
604+
// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
605605
func.func @arm_sme_streaming_vl_half_words() -> index {
606606
%svl_h = arm_sme.streaming_vl <half>
607607
return %svl_h : index
@@ -610,10 +610,10 @@ func.func @arm_sme_streaming_vl_half_words() -> index {
610610
// -----
611611

612612
// CHECK-LABEL: @arm_sme_streaming_vl_words
613-
// CHECK: %[[CONST:.*]] = arith.constant 2 : i64
613+
// CHECK: %[[CONST:.*]] = arith.constant 2 : index
614614
// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
615-
// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
616-
// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
615+
// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
616+
// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
617617
func.func @arm_sme_streaming_vl_words() -> index {
618618
%svl_w = arm_sme.streaming_vl <word>
619619
return %svl_w : index

0 commit comments

Comments
 (0)