Skip to content

Commit 819e6b2

Browse files
authored
[InstSimplify] Consider vscale_range for get active lane mask (#160073)
A scalable get_active_lane_mask intrinsic call can be simplified to an all-true i1 splat (ptrue) when its base operand is zero and its constant second operand is greater than or equal to the maximum possible number of mask elements, which can be inferred from vscale_range(x, y).
1 parent 66fd420 commit 819e6b2

File tree

4 files changed

+100
-49
lines changed

4 files changed

+100
-49
lines changed

llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6514,10 +6514,27 @@ Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
65146514
const CallBase *Call) {
65156515
unsigned BitWidth = ReturnType->getScalarSizeInBits();
65166516
switch (IID) {
6517-
case Intrinsic::get_active_lane_mask:
6517+
case Intrinsic::get_active_lane_mask: {
65186518
if (match(Op1, m_Zero()))
65196519
return ConstantInt::getFalse(ReturnType);
6520+
6521+
const Function *F = Call->getFunction();
6522+
auto *ScalableTy = dyn_cast<ScalableVectorType>(ReturnType);
6523+
Attribute Attr = F->getFnAttribute(Attribute::VScaleRange);
6524+
if (ScalableTy && Attr.isValid()) {
6525+
std::optional<unsigned> VScaleMax = Attr.getVScaleRangeMax();
6526+
if (!VScaleMax)
6527+
break;
6528+
uint64_t MaxPossibleMaskElements =
6529+
(uint64_t)ScalableTy->getMinNumElements() * (*VScaleMax);
6530+
6531+
const APInt *Op1Val;
6532+
if (match(Op0, m_Zero()) && match(Op1, m_APInt(Op1Val)) &&
6533+
Op1Val->uge(MaxPossibleMaskElements))
6534+
return ConstantInt::getAllOnesValue(ReturnType);
6535+
}
65206536
break;
6537+
}
65216538
case Intrinsic::abs:
65226539
// abs(abs(x)) -> abs(x). We don't need to worry about the nsw arg here.
65236540
// It is always ok to pick the earlier abs. We'll just lose nsw if its only

llvm/test/Transforms/InstSimplify/get_active_lane_mask.ll

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,51 @@ define <vscale x 8 x i1> @foo_nxv8i1(i32 %a) {
1818
%mask = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1(i32 %a, i32 0)
1919
ret <vscale x 8 x i1> %mask
2020
}
21+
22+
define <vscale x 16 x i1> @foo_vscale_max_255() vscale_range(1,16) {
23+
; CHECK-LABEL: define <vscale x 16 x i1> @foo_vscale_max_255(
24+
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
25+
; CHECK-NEXT: [[MASK:%.*]] = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 0, i32 255)
26+
; CHECK-NEXT: ret <vscale x 16 x i1> [[MASK]]
27+
;
28+
%mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1(i32 0, i32 255)
29+
ret <vscale x 16 x i1> %mask
30+
}
31+
32+
define <vscale x 16 x i1> @foo_vscale_max_256() vscale_range(1,16) {
33+
; CHECK-LABEL: define <vscale x 16 x i1> @foo_vscale_max_256(
34+
; CHECK-SAME: ) #[[ATTR0]] {
35+
; CHECK-NEXT: ret <vscale x 16 x i1> splat (i1 true)
36+
;
37+
%mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1(i32 0, i32 256)
38+
ret <vscale x 16 x i1> %mask
39+
}
40+
41+
define <vscale x 2 x i1> @foo_vscale_max_nxv2i1_1_1_2() vscale_range(1,1) {
42+
; CHECK-LABEL: define <vscale x 2 x i1> @foo_vscale_max_nxv2i1_1_1_2(
43+
; CHECK-SAME: ) #[[ATTR1:[0-9]+]] {
44+
; CHECK-NEXT: ret <vscale x 2 x i1> splat (i1 true)
45+
;
46+
%mask = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1(i32 0, i32 2)
47+
ret <vscale x 2 x i1> %mask
48+
}
49+
50+
define <vscale x 4 x i1> @foo_vscale_max_nxv4i1_2_4_16() vscale_range(2,4) {
51+
; CHECK-LABEL: define <vscale x 4 x i1> @foo_vscale_max_nxv4i1_2_4_16(
52+
; CHECK-SAME: ) #[[ATTR2:[0-9]+]] {
53+
; CHECK-NEXT: ret <vscale x 4 x i1> splat (i1 true)
54+
;
55+
%mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1(i128 0, i128 16)
56+
ret <vscale x 4 x i1> %mask
57+
}
58+
59+
define <vscale x 4 x i1> @foo_vscale_max_nxv4i1_2_4_1_16() vscale_range(2,4) {
60+
; CHECK-LABEL: define <vscale x 4 x i1> @foo_vscale_max_nxv4i1_2_4_1_16(
61+
; CHECK-SAME: ) #[[ATTR2]] {
62+
; CHECK-NEXT: [[MASK:%.*]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i128(i128 1, i128 16)
63+
; CHECK-NEXT: ret <vscale x 4 x i1> [[MASK]]
64+
;
65+
%mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1(i128 1, i128 16)
66+
ret <vscale x 4 x i1> %mask
67+
}
68+

0 commit comments

Comments
 (0)