Skip to content

Commit 0e6bb88

Browse files
Lukacmadvbuka
authored and committed
[AArch64] Optimized rdsvl followed by constant mul (llvm#162853)
Currently, when RDSVL is followed by a constant multiplication, no specific optimization exists that would leverage the immediate multiplication operand to generate simpler assembly. This patch adds such an optimization, allowing rewrites like the following when certain conditions are met: `(mul (srl (rdsvl 1), 3), x) -> (shl (rdsvl y), z)`
1 parent 90af493 commit 0e6bb88

File tree

2 files changed

+166
-1
lines changed

2 files changed

+166
-1
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19476,6 +19476,61 @@ static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
1947619476
Op1 ? Op1 : Mul->getOperand(1));
1947719477
}
1947819478

19479+
// Multiplying an RDSVL value by a constant can sometimes be done cheaper by
19480+
// folding a power-of-two factor of the constant into the RDSVL immediate and
19481+
// compensating with an extra shift.
19482+
//
19483+
// We rewrite:
19484+
// (mul (srl (rdsvl 1), w), x)
19485+
// to one of:
19486+
// (shl (rdsvl y), z) if z > 0
19487+
// (srl (rdsvl y), abs(z)) if z < 0
19488+
// where integers y, z satisfy x = y * 2^(w + z) and y ∈ [-32, 31].
19489+
static SDValue performMulRdsvlCombine(SDNode *Mul, SelectionDAG &DAG) {
19490+
SDLoc DL(Mul);
19491+
EVT VT = Mul->getValueType(0);
19492+
SDValue MulOp0 = Mul->getOperand(0);
19493+
int ConstMultiplier =
19494+
cast<ConstantSDNode>(Mul->getOperand(1))->getSExtValue();
19495+
if ((MulOp0->getOpcode() != ISD::SRL) ||
19496+
(MulOp0->getOperand(0).getOpcode() != AArch64ISD::RDSVL))
19497+
return SDValue();
19498+
19499+
unsigned AbsConstValue = abs(ConstMultiplier);
19500+
unsigned OperandShift =
19501+
cast<ConstantSDNode>(MulOp0->getOperand(1))->getZExtValue();
19502+
19503+
// z ≤ ctz(|x|) - w (largest extra shift we can take while keeping y
19504+
// integral)
19505+
int UpperBound = llvm::countr_zero(AbsConstValue) - OperandShift;
19506+
19507+
// To keep y in range, with B = 31 for x > 0 and B = 32 for x < 0, we need:
19508+
// 2^(w + z) ≥ ceil(x / B) ⇒ z ≥ ceil_log2(ceil(x / B)) - w (LowerBound).
19509+
unsigned B = ConstMultiplier < 0 ? 32 : 31;
19510+
unsigned CeilAxOverB = (AbsConstValue + (B - 1)) / B; // ceil(|x|/B)
19511+
int LowerBound = llvm::Log2_32_Ceil(CeilAxOverB) - OperandShift;
19512+
19513+
// No valid solution found.
19514+
if (LowerBound > UpperBound)
19515+
return SDValue();
19516+
19517+
// Any value of z in [LowerBound, UpperBound] is valid. Prefer no extra
19518+
// shift if possible.
19519+
int Shift = std::min(std::max(/*prefer*/ 0, LowerBound), UpperBound);
19520+
19521+
// y = x / 2^(w + z)
19522+
int32_t RdsvlMul = (AbsConstValue >> (OperandShift + Shift)) *
19523+
(ConstMultiplier < 0 ? -1 : 1);
19524+
auto Rdsvl = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
19525+
DAG.getSignedConstant(RdsvlMul, DL, MVT::i32));
19526+
19527+
if (Shift == 0)
19528+
return Rdsvl;
19529+
return DAG.getNode(Shift < 0 ? ISD::SRL : ISD::SHL, DL, VT, Rdsvl,
19530+
DAG.getConstant(abs(Shift), DL, MVT::i32),
19531+
SDNodeFlags::Exact);
19532+
}
19533+
1947919534
// Combine v4i32 Mul(And(Srl(X, 15), 0x10001), 0xffff) -> v8i16 CMLTz
1948019535
// Same for other types with equivalent constants.
1948119536
static SDValue performMulVectorCmpZeroCombine(SDNode *N, SelectionDAG &DAG) {
@@ -19604,6 +19659,9 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
1960419659
if (!isa<ConstantSDNode>(N1))
1960519660
return SDValue();
1960619661

19662+
if (SDValue Ext = performMulRdsvlCombine(N, DAG))
19663+
return Ext;
19664+
1960719665
ConstantSDNode *C = cast<ConstantSDNode>(N1);
1960819666
const APInt &ConstValue = C->getAPIntValue();
1960919667

llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,111 @@ define i64 @sme_cntsd_mul() {
8686
ret i64 %res
8787
}
8888

89-
declare i64 @llvm.aarch64.sme.cntsd()
; 8 * (SVLb/8) * 96 folds to (24 * SVLb) << 2.
define i64 @sme_cntsb_mul_pos() {
; CHECK-LABEL: sme_cntsb_mul_pos:
; CHECK:       // %bb.0:
; CHECK-NEXT:    rdsvl x8, #24
; CHECK-NEXT:    lsl x0, x8, #2
; CHECK-NEXT:    ret
  %cntd = call i64 @llvm.aarch64.sme.cntsd()
  %scaled = shl nuw nsw i64 %cntd, 3
  %prod = mul nuw nsw i64 %scaled, 96
  ret i64 %prod
}
; 4 * (SVLb/8) * 3 folds to (3 * SVLb) >> 1.
define i64 @sme_cntsh_mul_pos() {
; CHECK-LABEL: sme_cntsh_mul_pos:
; CHECK:       // %bb.0:
; CHECK-NEXT:    rdsvl x8, #3
; CHECK-NEXT:    lsr x0, x8, #1
; CHECK-NEXT:    ret
  %cntd = call i64 @llvm.aarch64.sme.cntsd()
  %scaled = shl nuw nsw i64 %cntd, 2
  %prod = mul nuw nsw i64 %scaled, 3
  ret i64 %prod
}
; 2 * (SVLb/8) * 62 folds to (31 * SVLb) >> 1.
define i64 @sme_cntsw_mul_pos() {
; CHECK-LABEL: sme_cntsw_mul_pos:
; CHECK:       // %bb.0:
; CHECK-NEXT:    rdsvl x8, #31
; CHECK-NEXT:    lsr x0, x8, #1
; CHECK-NEXT:    ret
  %cntd = call i64 @llvm.aarch64.sme.cntsd()
  %scaled = shl nuw nsw i64 %cntd, 1
  %prod = mul nuw nsw i64 %scaled, 62
  ret i64 %prod
}
; (SVLb/8) * 992 = (SVLb/8) * 31 * 32 folds to (31 * SVLb) << 2.
define i64 @sme_cntsd_mul_pos() {
; CHECK-LABEL: sme_cntsd_mul_pos:
; CHECK:       // %bb.0:
; CHECK-NEXT:    rdsvl x8, #31
; CHECK-NEXT:    lsl x0, x8, #2
; CHECK-NEXT:    ret
  %cntd = call i64 @llvm.aarch64.sme.cntsd()
  %prod = mul nuw nsw i64 %cntd, 992
  ret i64 %prod
}
; Negative multiplier: 8 * (SVLb/8) * -96 folds to (-24 * SVLb) << 2.
define i64 @sme_cntsb_mul_neg() {
; CHECK-LABEL: sme_cntsb_mul_neg:
; CHECK:       // %bb.0:
; CHECK-NEXT:    rdsvl x8, #-24
; CHECK-NEXT:    lsl x0, x8, #2
; CHECK-NEXT:    ret
  %cntd = call i64 @llvm.aarch64.sme.cntsd()
  %scaled = shl nuw nsw i64 %cntd, 3
  %prod = mul nuw nsw i64 %scaled, -96
  ret i64 %prod
}
; Negative multiplier: 4 * (SVLb/8) * -3 folds to (-3 * SVLb) >> 1.
define i64 @sme_cntsh_mul_neg() {
; CHECK-LABEL: sme_cntsh_mul_neg:
; CHECK:       // %bb.0:
; CHECK-NEXT:    rdsvl x8, #-3
; CHECK-NEXT:    lsr x0, x8, #1
; CHECK-NEXT:    ret
  %cntd = call i64 @llvm.aarch64.sme.cntsd()
  %scaled = shl nuw nsw i64 %cntd, 2
  %prod = mul nuw nsw i64 %scaled, -3
  ret i64 %prod
}
; Negative multiplier: 2 * (SVLb/8) * -992 folds to (-31 * SVLb) << 3.
define i64 @sme_cntsw_mul_neg() {
; CHECK-LABEL: sme_cntsw_mul_neg:
; CHECK:       // %bb.0:
; CHECK-NEXT:    rdsvl x8, #-31
; CHECK-NEXT:    lsl x0, x8, #3
; CHECK-NEXT:    ret
  %cntd = call i64 @llvm.aarch64.sme.cntsd()
  %scaled = shl nuw nsw i64 %cntd, 1
  %prod = mul nuw nsw i64 %scaled, -992
  ret i64 %prod
}
; Negative multiplier: (SVLb/8) * -3 folds to (-3 * SVLb) >> 3.
define i64 @sme_cntsd_mul_neg() {
; CHECK-LABEL: sme_cntsd_mul_neg:
; CHECK:       // %bb.0:
; CHECK-NEXT:    rdsvl x8, #-3
; CHECK-NEXT:    lsr x0, x8, #3
; CHECK-NEXT:    ret
  %cntd = call i64 @llvm.aarch64.sme.cntsd()
  %prod = mul nuw nsw i64 %cntd, -3
  ret i64 %prod
}
; Negative test: 993 is odd and exceeds the RDSVL immediate range, so no
; (rdsvl, shift) pair exists and the plain multiply must remain.
define i64 @sme_cntsd_mul_fail() {
; CHECK-LABEL: sme_cntsd_mul_fail:
; CHECK:       // %bb.0:
; CHECK-NEXT:    rdsvl x8, #1
; CHECK-NEXT:    mov w9, #993 // =0x3e1
; CHECK-NEXT:    lsr x8, x8, #3
; CHECK-NEXT:    mul x0, x8, x9
; CHECK-NEXT:    ret
  %cntd = call i64 @llvm.aarch64.sme.cntsd()
  %prod = mul nuw nsw i64 %cntd, 993
  ret i64 %prod
}

0 commit comments

Comments
 (0)