Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19476,6 +19476,61 @@ static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
Op1 ? Op1 : Mul->getOperand(1));
}

// Multiplying an RDSVL value by a constant can sometimes be done cheaper by
// folding a power-of-two factor of the constant into the RDSVL immediate and
// compensating with an extra shift.
//
// We rewrite:
//   (mul (srl (rdsvl 1), w), x)
// to one of:
//   (shl (rdsvl y), z)      if z > 0
//   (srl (rdsvl y), abs(z)) if z < 0
// where integers y, z satisfy x = y * 2^(w + z) and y is in [-32, 31], the
// range of the RDSVL immediate.
static SDValue performMulRdsvlCombine(SDNode *Mul, SelectionDAG &DAG) {
  SDLoc DL(Mul);
  EVT VT = Mul->getValueType(0);
  SDValue MulOp0 = Mul->getOperand(0);
  if ((MulOp0->getOpcode() != ISD::SRL) ||
      (MulOp0->getOperand(0).getOpcode() != AArch64ISD::RDSVL))
    return SDValue();

  // The arithmetic below assumes the RDSVL immediate is exactly 1. Bail out
  // for any other immediate (for instance on a node this combine itself
  // created), otherwise the existing factor would be silently dropped.
  auto *RdsvlImm =
      dyn_cast<ConstantSDNode>(MulOp0->getOperand(0).getOperand(0));
  if (!RdsvlImm || RdsvlImm->getSExtValue() != 1)
    return SDValue();

  // The shift amount w must be a constant; a variable srl amount cannot be
  // folded (and a plain cast<> would assert on it).
  auto *SrlImm = dyn_cast<ConstantSDNode>(MulOp0->getOperand(1));
  if (!SrlImm)
    return SDValue();
  unsigned OperandShift = SrlImm->getZExtValue();

  // Keep the multiplier in 64 bits: truncating to 'int' would fold with a
  // wrong constant for multipliers that do not fit in 32 bits.
  int64_t ConstMultiplier =
      cast<ConstantSDNode>(Mul->getOperand(1))->getSExtValue();
  // Multiplication by zero is simplified by generic combines; bail out here
  // so the shift arithmetic below stays well-defined.
  if (ConstMultiplier == 0)
    return SDValue();

  // |x|, computed without the UB std::abs() has on the most-negative value.
  uint64_t AbsConstValue = ConstMultiplier < 0
                               ? -static_cast<uint64_t>(ConstMultiplier)
                               : static_cast<uint64_t>(ConstMultiplier);

  // z <= ctz(|x|) - w (largest extra shift we can take while keeping y
  // integral).
  int UpperBound =
      llvm::countr_zero(AbsConstValue) - static_cast<int>(OperandShift);

  // To keep y in range, with B = 31 for x > 0 and B = 32 for x < 0, we need:
  // 2^(w + z) >= ceil(|x| / B) => z >= ceil_log2(ceil(|x| / B)) - w
  // (LowerBound).
  uint64_t B = ConstMultiplier < 0 ? 32 : 31;
  uint64_t CeilAxOverB = (AbsConstValue + (B - 1)) / B; // ceil(|x|/B)
  int LowerBound = static_cast<int>(llvm::Log2_64_Ceil(CeilAxOverB)) -
                   static_cast<int>(OperandShift);

  // No valid solution found.
  if (LowerBound > UpperBound)
    return SDValue();

  // Any value of z in [LowerBound, UpperBound] is valid. Prefer no extra
  // shift if possible.
  int Shift = std::min(std::max(/*prefer*/ 0, LowerBound), UpperBound);

  // y = x / 2^(w + z); LowerBound guarantees it fits the RDSVL immediate.
  int64_t RdsvlMul =
      static_cast<int64_t>(AbsConstValue >> (OperandShift + Shift)) *
      (ConstMultiplier < 0 ? -1 : 1);
  auto Rdsvl = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
                           DAG.getSignedConstant(RdsvlMul, DL, MVT::i32));

  if (Shift == 0)
    return Rdsvl;
  return DAG.getNode(Shift < 0 ? ISD::SRL : ISD::SHL, DL, VT, Rdsvl,
                     DAG.getConstant(std::abs(Shift), DL, MVT::i32),
                     SDNodeFlags::Exact);
}

// Combine v4i32 Mul(And(Srl(X, 15), 0x10001), 0xffff) -> v8i16 CMLTz
// Same for other types with equivalent constants.
static SDValue performMulVectorCmpZeroCombine(SDNode *N, SelectionDAG &DAG) {
Expand Down Expand Up @@ -19604,6 +19659,9 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
if (!isa<ConstantSDNode>(N1))
return SDValue();

if (SDValue Ext = performMulRdsvlCombine(N, DAG))
return Ext;

ConstantSDNode *C = cast<ConstantSDNode>(N1);
const APInt &ConstValue = C->getAPIntValue();

Expand Down
109 changes: 108 additions & 1 deletion llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,111 @@ define i64 @sme_cntsd_mul() {
ret i64 %res
}

declare i64 @llvm.aarch64.sme.cntsd()
; (cntsd << 3) * 96: since 96 = 24 * 2^2, the multiply folds into the RDSVL
; immediate (rdsvl #24) plus a left shift by 2.
define i64 @sme_cntsb_mul_pos() {
; CHECK-LABEL: sme_cntsb_mul_pos:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #24
; CHECK-NEXT: lsl x0, x8, #2
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%shl = shl nuw nsw i64 %v, 3
%res = mul nuw nsw i64 %shl, 96
ret i64 %res
}

; (cntsd << 2) * 3: the odd factor 3 becomes the RDSVL immediate and the
; residual power of two is taken as a right shift (rdsvl #3, lsr #1).
define i64 @sme_cntsh_mul_pos() {
; CHECK-LABEL: sme_cntsh_mul_pos:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #3
; CHECK-NEXT: lsr x0, x8, #1
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%shl = shl nuw nsw i64 %v, 2
%res = mul nuw nsw i64 %shl, 3
ret i64 %res
}

; (cntsd << 1) * 62: 62 = 31 * 2, with 31 the largest positive RDSVL
; immediate, leaving a net right shift by 1 (rdsvl #31, lsr #1).
define i64 @sme_cntsw_mul_pos() {
; CHECK-LABEL: sme_cntsw_mul_pos:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #31
; CHECK-NEXT: lsr x0, x8, #1
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%shl = shl nuw nsw i64 %v, 1
%res = mul nuw nsw i64 %shl, 62
ret i64 %res
}

; cntsd * 992: 992 = 31 * 32 is out of RDSVL immediate range, but splits into
; the maximum immediate 31 and a left shift by 2 (rdsvl #31, lsl #2).
define i64 @sme_cntsd_mul_pos() {
; CHECK-LABEL: sme_cntsd_mul_pos:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #31
; CHECK-NEXT: lsl x0, x8, #2
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%res = mul nuw nsw i64 %v, 992
ret i64 %res
}

; Negative-multiplier counterpart of sme_cntsb_mul_pos: -96 = -24 * 2^2, so
; the fold uses a negative RDSVL immediate (rdsvl #-24, lsl #2).
define i64 @sme_cntsb_mul_neg() {
; CHECK-LABEL: sme_cntsb_mul_neg:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #-24
; CHECK-NEXT: lsl x0, x8, #2
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%shl = shl nuw nsw i64 %v, 3
%res = mul nuw nsw i64 %shl, -96
ret i64 %res
}

; Negative-multiplier counterpart of sme_cntsh_mul_pos: the odd factor -3
; becomes the RDSVL immediate with a residual right shift (rdsvl #-3, lsr #1).
define i64 @sme_cntsh_mul_neg() {
; CHECK-LABEL: sme_cntsh_mul_neg:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #-3
; CHECK-NEXT: lsr x0, x8, #1
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%shl = shl nuw nsw i64 %v, 2
%res = mul nuw nsw i64 %shl, -3
ret i64 %res
}

; (cntsd << 1) * -992: -992 = -31 * 2^5, folded into a negative immediate and
; a left shift (rdsvl #-31, lsl #3).
define i64 @sme_cntsw_mul_neg() {
; CHECK-LABEL: sme_cntsw_mul_neg:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #-31
; CHECK-NEXT: lsl x0, x8, #3
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%shl = shl nuw nsw i64 %v, 1
%res = mul nuw nsw i64 %shl, -992
ret i64 %res
}

; cntsd * -3: no extra shift needed beyond the cntsd scaling, so the fold
; keeps the right shift by 3 with immediate -3 (rdsvl #-3, lsr #3).
define i64 @sme_cntsd_mul_neg() {
; CHECK-LABEL: sme_cntsd_mul_neg:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #-3
; CHECK-NEXT: lsr x0, x8, #3
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%res = mul nuw nsw i64 %v, -3
ret i64 %res
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a positive and a negative test where the value being multiplied by is out of range for the rdsvl immediate?

Copy link
Contributor Author

@Lukacma Lukacma Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I understand. I already have a test where the immediate is out of range for RDSVL, like here:

 %v = call i64 @llvm.aarch64.sme.cntsd()
  %res = mul nuw nsw i64 %v, 992
  ret i64 %res

Or are you asking for something else?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I don't think my question was very clear :)
I was wondering if there was a way to add a test case where we can't apply the optimisation. For example, this is a similar test I tried:

%v = call i64 @llvm.aarch64.sme.cntsd()
%res = mul nuw nsw i64 %v, 993
ret i64 %res

which results in:

rdsvl x8, #1
mov w9, #993
lsr x8, x8, #3
mul x0, x8, x9

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. I added the test.


; Negative test: 993 cannot be written as y * 2^(3 + z) with y in [-32, 31],
; so the combine must not fire and a plain multiply remains.
define i64 @sme_cntsd_mul_fail() {
; CHECK-LABEL: sme_cntsd_mul_fail:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov w9, #993 // =0x3e1
; CHECK-NEXT: lsr x8, x8, #3
; CHECK-NEXT: mul x0, x8, x9
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%res = mul nuw nsw i64 %v, 993
ret i64 %res
}