Skip to content
31 changes: 20 additions & 11 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18774,7 +18774,7 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
static SDValue
performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *ST) {
if (DCI.isBeforeLegalize())
if (DCI.isBeforeLegalize() && !!DCI.isBeforeLegalizeOps())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: The !! looks a little odd. Is it possible to just use DCI.isBeforeLegalizeOps()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a mistake, it should be !DCI.isBeforeLegalizeOps()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is doing what you intend. Given that !DCI.isBeforeLegalizeOps() means AfterLegalizeVectorOps, you've effectively written if "before type legalisation" and "after vector ops legalisation", which is always going to be false because they are at opposite ends of selection.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've now removed this altogether, since I found that a concat with 4+ inputs will be split after this combine even without changing the stage at which the combine applies (although I've also made changes to create multiple concat_vectors in the latest commit too).

return SDValue();

if (SDValue While = optimizeIncrementingWhile(N, DCI.DAG, /*IsSigned=*/false,
Expand All @@ -18785,21 +18785,27 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
(!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
return SDValue();

unsigned NumUses = N->use_size();
// Count the number of users which are extract_vectors
// The only other valid users for this combine are ptest_first
// and reinterpret_cast.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This last bit can be removed because the ptest_first restriction no longer applies.

unsigned NumExts = count_if(N->users(), [](SDNode *Use) {
return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR;
});

auto MaskEC = N->getValueType(0).getVectorElementCount();
if (!MaskEC.isKnownMultipleOf(NumUses))
if (NumExts == 0 || !MaskEC.isKnownMultipleOf(NumExts))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to move the zero check into isKnownMultipleOf?

return SDValue();

ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
if (ExtMinEC.getKnownMinValue() < 2)
return SDValue();

SmallVector<SDNode *> Extracts(NumUses, nullptr);
SmallVector<SDNode *> Extracts(NumExts, nullptr);
for (SDNode *Use : N->users()) {
if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
return SDValue();
continue;

// Ensure the extract type is correct (e.g. if NumUses is 4 and
// Ensure the extract type is correct (e.g. if NumExts is 4 and
// the mask return type is nxv8i1, each extract should be nxv2i1).
if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
return SDValue();
Expand Down Expand Up @@ -18832,20 +18838,23 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
DCI.CombineTo(Extracts[0], R.getValue(0));
DCI.CombineTo(Extracts[1], R.getValue(1));
SmallVector<SDValue> Results = {R.getValue(0), R.getValue(1)};

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is wrapping the operands in {} necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I've removed them in both places

if (NumUses == 2)
return SDValue(N, 0);
if (NumExts == 2)
return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Results);

auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
for (unsigned I = 2; I < NumUses; I += 2) {
for (unsigned I = 2; I < NumExts; I += 2) {
// After the first whilelo_x2, we need to increment the starting value.
Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
DCI.CombineTo(Extracts[I], R.getValue(0));
DCI.CombineTo(Extracts[I + 1], R.getValue(1));
Results.push_back(R.getValue(0));
Results.push_back(R.getValue(1));
}

return SDValue(N, 0);
return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Results);
}

// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
Expand Down
181 changes: 181 additions & 0 deletions llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,187 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #
ret void
}

; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first.

define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
; CHECK-SVE: // %bb.0: // %entry
; CHECK-SVE-NEXT: whilelo p1.b, x0, x1
; CHECK-SVE-NEXT: b.pl .LBB11_2
; CHECK-SVE-NEXT: // %bb.1: // %if.then
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: b use
; CHECK-SVE-NEXT: .LBB11_2: // %if.end
; CHECK-SVE-NEXT: ret
;
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB11_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
entry:
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
%v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
%v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
%elt0 = extractelement <vscale x 16 x i1> %r, i32 0
br i1 %elt0, label %if.then, label %if.end

if.then:
tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
br label %if.end

if.end:
ret void
}

; Extra use of the get_active_lane_mask from an extractelement, which is
; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.

define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's worth adding a similar test for the NumExts != 2 case, if only to see if that better exposes the issues I believe exist in the PR as it stands today.

; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE: // %bb.0: // %entry
; CHECK-SVE-NEXT: whilelo p1.h, x0, x1
; CHECK-SVE-NEXT: b.pl .LBB12_2
; CHECK-SVE-NEXT: // %bb.1: // %if.then
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: b use
; CHECK-SVE-NEXT: .LBB12_2: // %if.end
; CHECK-SVE-NEXT: ret
;
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB12_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
entry:
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
%v0 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 0)
%v1 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 4)
%elt0 = extractelement <vscale x 8 x i1> %r, i64 0
br i1 %elt0, label %if.then, label %if.end

if.then:
tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1)
br label %if.end

if.end:
ret void
}

define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
; CHECK-SVE-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
; CHECK-SVE: // %bb.0: // %entry
; CHECK-SVE-NEXT: whilelo p0.b, x0, x1
; CHECK-SVE-NEXT: b.pl .LBB13_2
; CHECK-SVE-NEXT: // %bb.1: // %if.then
; CHECK-SVE-NEXT: punpklo p1.h, p0.b
; CHECK-SVE-NEXT: punpkhi p3.h, p0.b
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: punpklo p2.h, p3.b
; CHECK-SVE-NEXT: punpkhi p3.h, p3.b
; CHECK-SVE-NEXT: b use
; CHECK-SVE-NEXT: .LBB13_2: // %if.end
; CHECK-SVE-NEXT: ret
;
; CHECK-SVE2p1-SME2-LABEL: test_4x4bit_mask_with_extracts_and_ptest:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: cnth x8
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
entry:
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
%v0 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
%v1 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 4)
%v2 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
%v3 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 12)
%elt0 = extractelement <vscale x 16 x i1> %r, i32 0
br i1 %elt0, label %if.then, label %if.end

if.then:
tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1, <vscale x 4 x i1> %v2, <vscale x 4 x i1> %v3)
br label %if.end

if.end:
ret void
}

define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
; CHECK-SVE-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE: // %bb.0: // %entry
; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
; CHECK-SVE-NEXT: b.pl .LBB14_2
; CHECK-SVE-NEXT: // %bb.1: // %if.then
; CHECK-SVE-NEXT: punpklo p1.h, p0.b
; CHECK-SVE-NEXT: punpkhi p3.h, p0.b
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: punpklo p2.h, p3.b
; CHECK-SVE-NEXT: punpkhi p3.h, p3.b
; CHECK-SVE-NEXT: b use
; CHECK-SVE-NEXT: .LBB14_2: // %if.end
; CHECK-SVE-NEXT: ret
;
; CHECK-SVE2p1-SME2-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: cntw x8
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
entry:
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i32(i64 %i, i64 %n)
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
%v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
%v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
%elt0 = extractelement <vscale x 8 x i1> %r, i32 0
br i1 %elt0, label %if.then, label %if.end

if.then:
tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v3)
br label %if.end

if.end:
ret void
}

declare void @use(...)

attributes #0 = { nounwind }