Skip to content

Commit cf50bbf

Browse files
[AArch64][SVE2p1] Allow more uses of mask in performActiveLaneMaskCombine (#159360)
The combine replaces a get_active_lane_mask used by two extract subvectors with a single paired whilelo intrinsic. When the instruction is used for control flow in a vector loop, an additional extract of element 0 may introduce other uses of the intrinsic, such as ptest and reinterpret cast, which are currently not supported. This patch changes performActiveLaneMaskCombine to count the number of extract subvectors using the mask instead of the total number of uses, and to return the concatenation of the whilelo results in place of the original get_active_lane_mask, so that any remaining users of the mask still receive a full-width value.
1 parent 343476e commit cf50bbf

File tree

4 files changed

+210
-16
lines changed

4 files changed

+210
-16
lines changed

llvm/include/llvm/Support/TypeSize.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ template <typename LeafTy, typename ValueTy> class FixedOrScalableQuantity {
179179
/// This function tells the caller whether the element count is known at
180180
/// compile time to be a multiple of the scalar value RHS.
181181
constexpr bool isKnownMultipleOf(ScalarTy RHS) const {
182-
return getKnownMinValue() % RHS == 0;
182+
return RHS != 0 && getKnownMinValue() % RHS == 0;
183183
}
184184

185185
/// Returns whether or not the callee is known to be a multiple of RHS.
@@ -191,7 +191,8 @@ template <typename LeafTy, typename ValueTy> class FixedOrScalableQuantity {
191191
// x % y == 0 !=> x % (vscale * y) == 0
192192
if (!isScalable() && RHS.isScalable())
193193
return false;
194-
return getKnownMinValue() % RHS.getKnownMinValue() == 0;
194+
return RHS.getKnownMinValue() != 0 &&
195+
getKnownMinValue() % RHS.getKnownMinValue() == 0;
195196
}
196197

197198
// Return the minimum value with the assumption that the count is exact.

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18867,21 +18867,25 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1886718867
(!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
1886818868
return SDValue();
1886918869

18870-
unsigned NumUses = N->use_size();
18870+
// Count the number of users which are extract_vectors.
18871+
unsigned NumExts = count_if(N->users(), [](SDNode *Use) {
18872+
return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR;
18873+
});
18874+
1887118875
auto MaskEC = N->getValueType(0).getVectorElementCount();
18872-
if (!MaskEC.isKnownMultipleOf(NumUses))
18876+
if (!MaskEC.isKnownMultipleOf(NumExts))
1887318877
return SDValue();
1887418878

18875-
ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
18879+
ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
1887618880
if (ExtMinEC.getKnownMinValue() < 2)
1887718881
return SDValue();
1887818882

18879-
SmallVector<SDNode *> Extracts(NumUses, nullptr);
18883+
SmallVector<SDNode *> Extracts(NumExts, nullptr);
1888018884
for (SDNode *Use : N->users()) {
1888118885
if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
18882-
return SDValue();
18886+
continue;
1888318887

18884-
// Ensure the extract type is correct (e.g. if NumUses is 4 and
18888+
// Ensure the extract type is correct (e.g. if NumExts is 4 and
1888518889
// the mask return type is nxv8i1, each extract should be nxv2i1).
1888618890
if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
1888718891
return SDValue();
@@ -18902,32 +18906,39 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1890218906

1890318907
SDValue Idx = N->getOperand(0);
1890418908
SDValue TC = N->getOperand(1);
18905-
EVT OpVT = Idx.getValueType();
18906-
if (OpVT != MVT::i64) {
18909+
if (Idx.getValueType() != MVT::i64) {
1890718910
Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
1890818911
TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
1890918912
}
1891018913

1891118914
// Create the whilelo_x2 intrinsics from each pair of extracts
1891218915
EVT ExtVT = Extracts[0]->getValueType(0);
18916+
EVT DoubleExtVT = ExtVT.getDoubleNumVectorElementsVT(*DAG.getContext());
1891318917
auto R =
1891418918
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
1891518919
DCI.CombineTo(Extracts[0], R.getValue(0));
1891618920
DCI.CombineTo(Extracts[1], R.getValue(1));
18921+
SmallVector<SDValue> Concats = {DAG.getNode(
18922+
ISD::CONCAT_VECTORS, DL, DoubleExtVT, R.getValue(0), R.getValue(1))};
1891718923

18918-
if (NumUses == 2)
18919-
return SDValue(N, 0);
18924+
if (NumExts == 2) {
18925+
assert(N->getValueType(0) == DoubleExtVT);
18926+
return Concats[0];
18927+
}
1892018928

18921-
auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
18922-
for (unsigned I = 2; I < NumUses; I += 2) {
18929+
auto Elts =
18930+
DAG.getElementCount(DL, MVT::i64, ExtVT.getVectorElementCount() * 2);
18931+
for (unsigned I = 2; I < NumExts; I += 2) {
1892318932
// After the first whilelo_x2, we need to increment the starting value.
18924-
Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
18933+
Idx = DAG.getNode(ISD::UADDSAT, DL, MVT::i64, Idx, Elts);
1892518934
R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
1892618935
DCI.CombineTo(Extracts[I], R.getValue(0));
1892718936
DCI.CombineTo(Extracts[I + 1], R.getValue(1));
18937+
Concats.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, DoubleExtVT,
18938+
R.getValue(0), R.getValue(1)));
1892818939
}
1892918940

18930-
return SDValue(N, 0);
18941+
return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Concats);
1893118942
}
1893218943

1893318944
// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce

llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,187 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #
310310
ret void
311311
}
312312

313+
; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first.
314+
315+
define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
316+
; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
317+
; CHECK-SVE: // %bb.0: // %entry
318+
; CHECK-SVE-NEXT: whilelo p1.b, x0, x1
319+
; CHECK-SVE-NEXT: b.pl .LBB11_2
320+
; CHECK-SVE-NEXT: // %bb.1: // %if.then
321+
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
322+
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
323+
; CHECK-SVE-NEXT: b use
324+
; CHECK-SVE-NEXT: .LBB11_2: // %if.end
325+
; CHECK-SVE-NEXT: ret
326+
;
327+
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
328+
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
329+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
330+
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b
331+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b
332+
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
333+
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
334+
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
335+
; CHECK-SVE2p1-SME2-NEXT: b use
336+
; CHECK-SVE2p1-SME2-NEXT: .LBB11_2: // %if.end
337+
; CHECK-SVE2p1-SME2-NEXT: ret
338+
entry:
339+
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
340+
%v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
341+
%v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
342+
%elt0 = extractelement <vscale x 16 x i1> %r, i32 0
343+
br i1 %elt0, label %if.then, label %if.end
344+
345+
if.then:
346+
tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
347+
br label %if.end
348+
349+
if.end:
350+
ret void
351+
}
352+
353+
; Extra use of the get_active_lane_mask from an extractelement, which is
354+
; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.
355+
356+
define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
357+
; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
358+
; CHECK-SVE: // %bb.0: // %entry
359+
; CHECK-SVE-NEXT: whilelo p1.h, x0, x1
360+
; CHECK-SVE-NEXT: b.pl .LBB12_2
361+
; CHECK-SVE-NEXT: // %bb.1: // %if.then
362+
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
363+
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
364+
; CHECK-SVE-NEXT: b use
365+
; CHECK-SVE-NEXT: .LBB12_2: // %if.end
366+
; CHECK-SVE-NEXT: ret
367+
;
368+
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
369+
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
370+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
371+
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h
372+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h
373+
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
374+
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
375+
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
376+
; CHECK-SVE2p1-SME2-NEXT: b use
377+
; CHECK-SVE2p1-SME2-NEXT: .LBB12_2: // %if.end
378+
; CHECK-SVE2p1-SME2-NEXT: ret
379+
entry:
380+
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
381+
%v0 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 0)
382+
%v1 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 4)
383+
%elt0 = extractelement <vscale x 8 x i1> %r, i64 0
384+
br i1 %elt0, label %if.then, label %if.end
385+
386+
if.then:
387+
tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1)
388+
br label %if.end
389+
390+
if.end:
391+
ret void
392+
}
393+
394+
define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
395+
; CHECK-SVE-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
396+
; CHECK-SVE: // %bb.0: // %entry
397+
; CHECK-SVE-NEXT: whilelo p0.b, x0, x1
398+
; CHECK-SVE-NEXT: b.pl .LBB13_2
399+
; CHECK-SVE-NEXT: // %bb.1: // %if.then
400+
; CHECK-SVE-NEXT: punpklo p1.h, p0.b
401+
; CHECK-SVE-NEXT: punpkhi p3.h, p0.b
402+
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
403+
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
404+
; CHECK-SVE-NEXT: punpklo p2.h, p3.b
405+
; CHECK-SVE-NEXT: punpkhi p3.h, p3.b
406+
; CHECK-SVE-NEXT: b use
407+
; CHECK-SVE-NEXT: .LBB13_2: // %if.end
408+
; CHECK-SVE-NEXT: ret
409+
;
410+
; CHECK-SVE2p1-SME2-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
411+
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
412+
; CHECK-SVE2p1-SME2-NEXT: cnth x8
413+
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
414+
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
415+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
416+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
417+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h
418+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h
419+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b
420+
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b
421+
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
422+
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2
423+
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
424+
; CHECK-SVE2p1-SME2-NEXT: b use
425+
; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end
426+
; CHECK-SVE2p1-SME2-NEXT: ret
427+
entry:
428+
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
429+
%v0 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
430+
%v1 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 4)
431+
%v2 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
432+
%v3 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 12)
433+
%elt0 = extractelement <vscale x 16 x i1> %r, i32 0
434+
br i1 %elt0, label %if.then, label %if.end
435+
436+
if.then:
437+
tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1, <vscale x 4 x i1> %v2, <vscale x 4 x i1> %v3)
438+
br label %if.end
439+
440+
if.end:
441+
ret void
442+
}
443+
444+
define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
445+
; CHECK-SVE-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
446+
; CHECK-SVE: // %bb.0: // %entry
447+
; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
448+
; CHECK-SVE-NEXT: b.pl .LBB14_2
449+
; CHECK-SVE-NEXT: // %bb.1: // %if.then
450+
; CHECK-SVE-NEXT: punpklo p1.h, p0.b
451+
; CHECK-SVE-NEXT: punpkhi p3.h, p0.b
452+
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
453+
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
454+
; CHECK-SVE-NEXT: punpklo p2.h, p3.b
455+
; CHECK-SVE-NEXT: punpkhi p3.h, p3.b
456+
; CHECK-SVE-NEXT: b use
457+
; CHECK-SVE-NEXT: .LBB14_2: // %if.end
458+
; CHECK-SVE-NEXT: ret
459+
;
460+
; CHECK-SVE2p1-SME2-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
461+
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
462+
; CHECK-SVE2p1-SME2-NEXT: cntw x8
463+
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
464+
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
465+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
466+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
467+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s
468+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s
469+
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h
470+
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h
471+
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
472+
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2
473+
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
474+
; CHECK-SVE2p1-SME2-NEXT: b use
475+
; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end
476+
; CHECK-SVE2p1-SME2-NEXT: ret
477+
entry:
478+
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i32(i64 %i, i64 %n)
479+
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
480+
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
481+
%v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
482+
%v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
483+
%elt0 = extractelement <vscale x 8 x i1> %r, i32 0
484+
br i1 %elt0, label %if.then, label %if.end
485+
486+
if.then:
487+
tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v3)
488+
br label %if.end
489+
490+
if.end:
491+
ret void
492+
}
493+
313494
declare void @use(...)
314495

315496
attributes #0 = { nounwind }

llvm/unittests/Support/TypeSizeTest.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ static_assert(ElementCount::getFixed(8).divideCoefficientBy(2) ==
5858
static_assert(ElementCount::getFixed(8).multiplyCoefficientBy(3) ==
5959
ElementCount::getFixed(24));
6060
static_assert(ElementCount::getFixed(8).isKnownMultipleOf(2));
61+
static_assert(!ElementCount::getFixed(8).isKnownMultipleOf(0));
6162

6263
constexpr TypeSize TSFixed0 = TypeSize::getFixed(0);
6364
constexpr TypeSize TSFixed1 = TypeSize::getFixed(1);

0 commit comments

Comments
 (0)