[AArch64] Optimized rdsvl followed by constant mul #162853
base: main
Conversation
@llvm/pr-subscribers-backend-aarch64

Author: None (Lukacma)

Changes

Currently, when RDSVL is followed by a constant multiplication, no optimization exists that leverages the immediate multiplication operand to generate simpler assembly. This patch adds such an optimization, allowing rewrites like the following when certain conditions are met:

(mul (srl (rdsvl 1), 3), x) -> (shl (rdsvl y), z)

Full diff: https://github.com/llvm/llvm-project/pull/162853.diff

2 Files Affected:

- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (modified)
- llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll (modified)
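To make the conditions concrete, here is a worked instance, using the effective constant from the sme_cntsb_mul_pos test in the diff below (the shl by 3 is first folded into the multiply, giving x = 96 * 8 = 768):

  ctz(768) = 8, so z <= 8 - 3 = 5 (UpperBound)
  ceil(768 / 31) = 25, so z >= ceil_log2(25) - 3 = 2 (LowerBound)
  choosing z = 2 gives y = 768 / 2^(3 + 2) = 24
  (mul (srl (rdsvl 1), 3), 768) -> (shl (rdsvl 24), 2), i.e. rdsvl x8, #24 followed by lsl x0, x8, #2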
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dc8e7c84f5e2c..1877b13a27c30 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19579,6 +19579,47 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
if (ConstValue.sge(1) && ConstValue.sle(16))
return SDValue();
+  // Multiplying an RDSVL value by a constant can sometimes be done more
+  // cheaply by splitting the constant into a small factor that fits the
+  // RDSVL immediate and a power-of-two factor that is folded into the
+  // existing shift or compensated with an extra shift.
+  //
+  // We rewrite:
+  //   (mul (srl (rdsvl 1), 3), x)
+  // to one of:
+  //   (shl (rdsvl y), z)      if z > 0
+  //   (srl (rdsvl y), abs(z)) if z < 0
+  // where integers y, z satisfy x = y * 2^(3 + z) and y ∈ [-32, 31].
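+  //
+  // For example, x = 768 = 24 * 2^(3 + 2) admits y = 24, z = 2, so the
+  // multiply becomes (shl (rdsvl 24), 2).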
+  if ((N0->getOpcode() == ISD::SRL) &&
+      (N0->getOperand(0).getOpcode() == AArch64ISD::RDSVL)) {
+    // Use a 64-bit value so multipliers wider than 32 bits are not truncated.
+    uint64_t AbsConstValue = ConstValue.abs().getZExtValue();
+
+    // z ≤ ctz(|x|) - 3 (the largest extra shift we can take while keeping y
+    // integral).
+    int UpperBound = llvm::countr_zero(AbsConstValue) - 3;
+
+    // To keep y in range, with B = 31 for x > 0 and B = 32 for x < 0, we need
+    // 2^(3 + z) ≥ ceil(|x| / B) ⇒ z ≥ ceil_log2(ceil(|x| / B)) - 3
+    // (LowerBound).
+    unsigned B = ConstValue.isNegative() ? 32 : 31;
+    uint64_t CeilAxOverB = (AbsConstValue + (B - 1)) / B; // ceil(|x|/B)
+    int LowerBound = int(llvm::Log2_64_Ceil(CeilAxOverB)) - 3;
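+    // For example, |x| = 768: ctz(768) = 8 gives UpperBound = 5, and
+    // ceil(768 / 31) = 25 gives LowerBound = ceil_log2(25) - 3 = 2.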
+
+    // If a solution exists, apply the rewrite, preferring z = 0 (no extra
+    // shift) whenever the bounds allow it.
+    if (LowerBound <= UpperBound) {
+      int Shift = std::min(std::max(/*prefer*/ 0, LowerBound), UpperBound);
+      int32_t RdsvlMul = int32_t(AbsConstValue >> (3 + Shift));
+      if (ConstValue.isNegative())
+        RdsvlMul = -RdsvlMul;
+      auto Rdsvl = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                               DAG.getSignedConstant(RdsvlMul, DL, MVT::i32));
+
+      if (Shift == 0)
+        return Rdsvl;
+      return DAG.getNode(Shift < 0 ? ISD::SRL : ISD::SHL, DL, VT, Rdsvl,
+                         DAG.getConstant(std::abs(Shift), DL, MVT::i32),
+                         SDNodeFlags::Exact);
+    }
+  }
+
// Multiplication of a power of two plus/minus one can be done more
// cheaply as shift+add/sub. For now, this is true unilaterally. If
// future CPUs have a cheaper MADD instruction, this may need to be
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
index 06c53d8070781..ea0057e4cfdef 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
@@ -86,4 +86,98 @@ define i64 @sme_cntsd_mul() {
ret i64 %res
}
+define i64 @sme_cntsb_mul_pos() {
+; CHECK-LABEL: sme_cntsb_mul_pos:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #24
+; CHECK-NEXT:    lsl x0, x8, #2
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 3
+  %res = mul nuw nsw i64 %shl, 96
+  ret i64 %res
+}
+
+define i64 @sme_cntsh_mul_pos() {
+; CHECK-LABEL: sme_cntsh_mul_pos:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #3
+; CHECK-NEXT:    lsr x0, x8, #1
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 2
+  %res = mul nuw nsw i64 %shl, 3
+  ret i64 %res
+}
+
+define i64 @sme_cntsw_mul_pos() {
+; CHECK-LABEL: sme_cntsw_mul_pos:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #31
+; CHECK-NEXT:    lsr x0, x8, #1
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 1
+  %res = mul nuw nsw i64 %shl, 62
+  ret i64 %res
+}
+
+define i64 @sme_cntsd_mul_pos() {
+; CHECK-LABEL: sme_cntsd_mul_pos:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #31
+; CHECK-NEXT:    lsl x0, x8, #2
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %res = mul nuw nsw i64 %v, 992
+  ret i64 %res
+}
+
+define i64 @sme_cntsb_mul_neg() {
+; CHECK-LABEL: sme_cntsb_mul_neg:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #-24
+; CHECK-NEXT:    lsl x0, x8, #2
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 3
+  %res = mul nuw nsw i64 %shl, -96
+  ret i64 %res
+}
+
+define i64 @sme_cntsh_mul_neg() {
+; CHECK-LABEL: sme_cntsh_mul_neg:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #-3
+; CHECK-NEXT:    lsr x0, x8, #1
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 2
+  %res = mul nuw nsw i64 %shl, -3
+  ret i64 %res
+}
+
+define i64 @sme_cntsw_mul_neg() {
+; CHECK-LABEL: sme_cntsw_mul_neg:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #-31
+; CHECK-NEXT:    lsl x0, x8, #3
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %shl = shl nuw nsw i64 %v, 1
+  %res = mul nuw nsw i64 %shl, -992
+  ret i64 %res
+}
+
+define i64 @sme_cntsd_mul_neg() {
+; CHECK-LABEL: sme_cntsd_mul_neg:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #-3
+; CHECK-NEXT:    lsr x0, x8, #3
+; CHECK-NEXT:    ret
+  %v = call i64 @llvm.aarch64.sme.cntsd()
+  %res = mul nuw nsw i64 %v, -3
+  ret i64 %res
+}
+
declare i64 @llvm.aarch64.sme.cntsd()
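For illustration, here is a sketch of C source that should reach the new combine. It is not part of the patch; it assumes clang's SME ACLE support, the svcntsd() intrinsic from <arm_sme.h>, and compilation with SME enabled (e.g. -O2 -march=armv9-a+sme; exact flags may vary):

#include <arm_sme.h>
#include <stdint.h>

// svcntsd() returns the streaming vector length in doublewords; it lowers
// to llvm.aarch64.sme.cntsd, i.e. (srl (rdsvl 1), 3).
uint64_t scaled_svl(void) {
  // 96 = 12 * 2^3, so with this patch the multiply should fold into the
  // RDSVL immediate, yielding a single "rdsvl x0, #12" (y = 12, z = 0).
  return svcntsd() * 96;
}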