Commit 52434bc
[AArch64] Add custom lowering of nxv32i1 get.active.lane.mask nodes
performActiveLaneMaskCombine already tries to combine a single get.active.lane.mask whose low and high halves are extracted into one whilelo operating on a predicate pair. However, if the get.active.lane.mask node requires splitting during type legalisation, multiple nodes are created whose starting indices are incremented with saturating adds, and these cannot be combined into a single whilelo_x2 afterwards unless we know the adds will not overflow. This patch instead adds custom lowering for the node when the return type is nxv32i1, which can be replaced directly with a whilelo_x2 using legal types. Anything wider than nxv32i1 still requires splitting first.
1 parent b61144b · commit 52434bc
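For context on the overflow concern above, here is a minimal standalone C++ sketch (illustration only, not part of the patch; HalfElts = 16 stands in for one nxv16i1 half, assuming vscale = 1). It models the no-wrap lane semantics of get.active.lane.mask alongside the saturating add that splitting introduces for the high half's start index. The two forms agree only because of the saturation, and that is precisely the equality a post-split combine cannot assume without knowing the add does not overflow.

#include <cassert>
#include <cstdint>

// Lane L of get.active.lane.mask(Base, TC) is active iff Base + L < TC,
// where the addition behaves as if performed in a wider type (no wrap).
static bool laneActive(uint64_t Base, uint64_t TC, uint64_t L) {
  return Base < TC && L < TC - Base;
}

// Splitting rewrites the high half as a new get.active.lane.mask whose
// start index is a saturating add, uadd.sat(Base, HalfElts).
static uint64_t uaddSat(uint64_t A, uint64_t B) {
  uint64_t R = A + B;
  return R < A ? UINT64_MAX : R;
}

int main() {
  const uint64_t HalfElts = 16; // one nxv16i1 half, assuming vscale = 1
  const uint64_t Bases[] = {0, 100, UINT64_MAX - 8};
  const uint64_t TCs[] = {0, 105, UINT64_MAX};
  for (uint64_t Base : Bases)
    for (uint64_t TC : TCs)
      for (uint64_t L = 0; L < HalfElts; ++L) {
        // High-half lane computed through the split form...
        bool Split = laneActive(uaddSat(Base, HalfElts), TC, L);
        // ...matches lane HalfElts + L of the unsplit mask, but only
        // thanks to the saturation; recovering a plain Base + HalfElts
        // (as a single whilelo_x2 would need) requires an overflow proof.
        bool Unsplit = laneActive(Base, TC, HalfElts + L);
        assert(Split == Unsplit);
      }
  return 0;
}

Lowering the nxv32i1 node before type legalisation splits it sidesteps this entirely: no saturating add is ever created, so a single whilelo_x2 on the original index and trip count is safe.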

File tree: 3 files changed (+181, -1 lines)

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 28 additions & 0 deletions
@@ -1501,6 +1501,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Legal);
     }
 
+    setOperationAction(ISD::GET_ACTIVE_LANE_MASK, MVT::nxv32i1, Custom);
+
     for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
       setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
   }
@@ -27328,6 +27330,29 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults(
   Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
 }
 
+void AArch64TargetLowering::ReplaceGetActiveLaneMaskResults(
+    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
+  if (!Subtarget->hasSVE2p1())
+    return;
+
+  SDLoc DL(N);
+  SDValue Idx = N->getOperand(0);
+  SDValue TC = N->getOperand(1);
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  }
+
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  EVT HalfVT = N->getValueType(0).getHalfNumVectorElementsVT(*DAG.getContext());
+  auto WideMask =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {HalfVT, HalfVT}, {ID, Idx, TC});
+
+  Results.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0),
+                                {WideMask.getValue(0), WideMask.getValue(1)}));
+}
+
 // Create an even/odd pair of X registers holding integer value V.
 static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
   SDLoc dl(V.getNode());
@@ -27714,6 +27739,9 @@ void AArch64TargetLowering::ReplaceNodeResults(
     // CONCAT_VECTORS -- but delegate to common code for result type
     // legalisation
     return;
+  case ISD::GET_ACTIVE_LANE_MASK:
+    ReplaceGetActiveLaneMaskResults(N, Results, DAG);
+    return;
   case ISD::INTRINSIC_WO_CHAIN: {
     EVT VT = N->getValueType(0);

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 0 deletions
@@ -1318,6 +1318,9 @@ class AArch64TargetLowering : public TargetLowering {
   void ReplaceExtractSubVectorResults(SDNode *N,
                                       SmallVectorImpl<SDValue> &Results,
                                       SelectionDAG &DAG) const;
+  void ReplaceGetActiveLaneMaskResults(SDNode *N,
+                                       SmallVectorImpl<SDValue> &Results,
+                                       SelectionDAG &DAG) const;
 
   bool shouldNormalizeToSelectSequence(LLVMContext &, EVT) const override;

llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll

Lines changed: 150 additions & 1 deletion
@@ -111,7 +111,7 @@ define void @test_partial_extract(i64 %i, i64 %n) #0 {
   ret void
 }
 
-;; Negative test for when extracting a fixed-length vector.
+; Negative test for when extracting a fixed-length vector.
 define void @test_fixed_extract(i64 %i, i64 %n) #0 {
 ; CHECK-SVE-LABEL: test_fixed_extract:
 ; CHECK-SVE:       // %bb.0:
@@ -151,6 +151,155 @@ define void @test_fixed_extract(i64 %i, i64 %n) #0 {
   ret void
 }
 
+; Illegal Types
+
+define void @test_2x16bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
+; CHECK-SVE-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    rdvl x8, #1
+; CHECK-SVE-NEXT:    adds w8, w0, w8
+; CHECK-SVE-NEXT:    csinv w8, w8, wzr, lo
+; CHECK-SVE-NEXT:    whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT:    whilelo p1.b, w8, w1
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    mov w8, w1
+; CHECK-SVE2p1-NEXT:    mov w9, w0
+; CHECK-SVE2p1-NEXT:    whilelo { p0.b, p1.b }, x9, x8
+; CHECK-SVE2p1-NEXT:    b use
+  %r = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32 %i, i32 %n)
+  %v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 0)
+  %v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 16)
+  tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1)
+  ret void
+}
+
+define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
+; CHECK-SVE-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    rdvl x8, #2
+; CHECK-SVE-NEXT:    rdvl x9, #1
+; CHECK-SVE-NEXT:    adds w8, w0, w8
+; CHECK-SVE-NEXT:    csinv w8, w8, wzr, lo
+; CHECK-SVE-NEXT:    adds w10, w8, w9
+; CHECK-SVE-NEXT:    csinv w10, w10, wzr, lo
+; CHECK-SVE-NEXT:    whilelo p3.b, w10, w1
+; CHECK-SVE-NEXT:    adds w9, w0, w9
+; CHECK-SVE-NEXT:    csinv w9, w9, wzr, lo
+; CHECK-SVE-NEXT:    whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT:    whilelo p1.b, w9, w1
+; CHECK-SVE-NEXT:    whilelo p2.b, w8, w1
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    rdvl x8, #2
+; CHECK-SVE2p1-NEXT:    mov w9, w1
+; CHECK-SVE2p1-NEXT:    mov w10, w0
+; CHECK-SVE2p1-NEXT:    adds w8, w0, w8
+; CHECK-SVE2p1-NEXT:    csinv w8, w8, wzr, lo
+; CHECK-SVE2p1-NEXT:    whilelo { p0.b, p1.b }, x10, x9
+; CHECK-SVE2p1-NEXT:    whilelo { p2.b, p3.b }, x8, x9
+; CHECK-SVE2p1-NEXT:    b use
+  %r = call <vscale x 64 x i1> @llvm.get.active.lane.mask.nxv64i1.i32(i32 %i, i32 %n)
+  %v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 0)
+  %v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 16)
+  %v2 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 32)
+  %v3 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 48)
+  tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1, <vscale x 16 x i1> %v2, <vscale x 16 x i1> %v3)
+  ret void
+}
+
+define void @test_2x16bit_mask_with_32bit_index_and_trip_count_ext8(i32 %i, i32 %n) #0 {
+; CHECK-SVE-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count_ext8:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SVE-NEXT:    rdvl x8, #1
+; CHECK-SVE-NEXT:    adds w8, w0, w8
+; CHECK-SVE-NEXT:    csinv w8, w8, wzr, lo
+; CHECK-SVE-NEXT:    whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT:    whilelo p4.b, w8, w1
+; CHECK-SVE-NEXT:    punpklo p1.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p0.b
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    punpklo p2.h, p3.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p3.b
+; CHECK-SVE-NEXT:    bl use
+; CHECK-SVE-NEXT:    punpklo p1.h, p4.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p4.b
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    punpklo p2.h, p3.b
+; CHECK-SVE-NEXT:    punpkhi p3.h, p3.b
+; CHECK-SVE-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count_ext8:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-SVE2p1-NEXT:    mov w8, w1
+; CHECK-SVE2p1-NEXT:    mov w9, w0
+; CHECK-SVE2p1-NEXT:    whilelo { p4.b, p5.b }, x9, x8
+; CHECK-SVE2p1-NEXT:    punpklo p1.h, p4.b
+; CHECK-SVE2p1-NEXT:    punpkhi p3.h, p4.b
+; CHECK-SVE2p1-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE2p1-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE2p1-NEXT:    punpklo p2.h, p3.b
+; CHECK-SVE2p1-NEXT:    punpkhi p3.h, p3.b
+; CHECK-SVE2p1-NEXT:    bl use
+; CHECK-SVE2p1-NEXT:    punpklo p1.h, p5.b
+; CHECK-SVE2p1-NEXT:    punpkhi p3.h, p5.b
+; CHECK-SVE2p1-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE2p1-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE2p1-NEXT:    punpklo p2.h, p3.b
+; CHECK-SVE2p1-NEXT:    punpkhi p3.h, p3.b
+; CHECK-SVE2p1-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-SVE2p1-NEXT:    b use
+  %r = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32 %i, i32 %n)
+  %v0 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 0)
+  %v1 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 4)
+  %v2 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 8)
+  %v3 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 12)
+  tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1, <vscale x 4 x i1> %v2, <vscale x 4 x i1> %v3)
+  %v4 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 16)
+  %v5 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 20)
+  %v6 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 24)
+  %v7 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 28)
+  tail call void @use(<vscale x 4 x i1> %v4, <vscale x 4 x i1> %v5, <vscale x 4 x i1> %v6, <vscale x 4 x i1> %v7)
+  ret void
+}
+
+; Negative test for when not extracting exactly two halves of the source vector
+define void @test_illegal_type_with_partial_extracts(i32 %i, i32 %n) #0 {
+; CHECK-SVE-LABEL: test_illegal_type_with_partial_extracts:
+; CHECK-SVE:       // %bb.0:
+; CHECK-SVE-NEXT:    rdvl x8, #1
+; CHECK-SVE-NEXT:    adds w8, w0, w8
+; CHECK-SVE-NEXT:    csinv w8, w8, wzr, lo
+; CHECK-SVE-NEXT:    whilelo p0.b, w0, w1
+; CHECK-SVE-NEXT:    whilelo p1.b, w8, w1
+; CHECK-SVE-NEXT:    punpkhi p0.h, p0.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+;
+; CHECK-SVE2p1-LABEL: test_illegal_type_with_partial_extracts:
+; CHECK-SVE2p1:       // %bb.0:
+; CHECK-SVE2p1-NEXT:    mov w8, w1
+; CHECK-SVE2p1-NEXT:    mov w9, w0
+; CHECK-SVE2p1-NEXT:    whilelo { p2.b, p3.b }, x9, x8
+; CHECK-SVE2p1-NEXT:    punpkhi p0.h, p2.b
+; CHECK-SVE2p1-NEXT:    punpkhi p1.h, p3.b
+; CHECK-SVE2p1-NEXT:    b use
+  %r = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32 %i, i32 %n)
+  %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 8)
+  %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 24)
+  tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
+  ret void
+}
+
 declare void @use(...)
 
 attributes #0 = { nounwind }
