-
Couldn't load subscription status.
- Fork 15k
[AArch64][SVE2p1] Allow more uses of mask in performActiveLaneMaskCombine #159360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
7f042ec
64b6d61
70adaf7
083701d
d64e2d8
c077d34
a815152
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18774,7 +18774,7 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N, | |
| static SDValue | ||
| performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, | ||
| const AArch64Subtarget *ST) { | ||
| if (DCI.isBeforeLegalize()) | ||
| if (DCI.isBeforeLegalize() && !!DCI.isBeforeLegalizeOps()) | ||
| return SDValue(); | ||
|
|
||
| if (SDValue While = optimizeIncrementingWhile(N, DCI.DAG, /*IsSigned=*/false, | ||
|
|
@@ -18785,21 +18785,27 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, | |
| (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming()))) | ||
| return SDValue(); | ||
|
|
||
| unsigned NumUses = N->use_size(); | ||
| // Count the number of users which are extract_vectors | ||
| // The only other valid users for this combine are ptest_first | ||
| // and reinterpret_cast. | ||
|
||
| unsigned NumExts = count_if(N->users(), [](SDNode *Use) { | ||
| return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR; | ||
| }); | ||
|
|
||
| auto MaskEC = N->getValueType(0).getVectorElementCount(); | ||
| if (!MaskEC.isKnownMultipleOf(NumUses)) | ||
| if (NumExts == 0 || !MaskEC.isKnownMultipleOf(NumExts)) | ||
|
||
| return SDValue(); | ||
|
|
||
| ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses); | ||
| ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts); | ||
paulwalker-arm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if (ExtMinEC.getKnownMinValue() < 2) | ||
| return SDValue(); | ||
|
|
||
| SmallVector<SDNode *> Extracts(NumUses, nullptr); | ||
| SmallVector<SDNode *> Extracts(NumExts, nullptr); | ||
| for (SDNode *Use : N->users()) { | ||
| if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR) | ||
| return SDValue(); | ||
| continue; | ||
|
|
||
| // Ensure the extract type is correct (e.g. if NumUses is 4 and | ||
| // Ensure the extract type is correct (e.g. if NumExts is 4 and | ||
| // the mask return type is nxv8i1, each extract should be nxv2i1. | ||
| if (Use->getValueType(0).getVectorElementCount() != ExtMinEC) | ||
| return SDValue(); | ||
|
|
@@ -18832,20 +18838,23 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, | |
| DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC}); | ||
| DCI.CombineTo(Extracts[0], R.getValue(0)); | ||
| DCI.CombineTo(Extracts[1], R.getValue(1)); | ||
paulwalker-arm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| SmallVector<SDValue> Results = {R.getValue(0), R.getValue(1)}; | ||
|
|
||
|
||
| if (NumUses == 2) | ||
| return SDValue(N, 0); | ||
| if (NumExts == 2) | ||
| return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Results); | ||
|
|
||
| auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2); | ||
| for (unsigned I = 2; I < NumUses; I += 2) { | ||
| for (unsigned I = 2; I < NumExts; I += 2) { | ||
| // After the first whilelo_x2, we need to increment the starting value. | ||
| Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts); | ||
| R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC}); | ||
| DCI.CombineTo(Extracts[I], R.getValue(0)); | ||
| DCI.CombineTo(Extracts[I + 1], R.getValue(1)); | ||
| Results.push_back(R.getValue(0)); | ||
| Results.push_back(R.getValue(1)); | ||
| } | ||
|
|
||
| return SDValue(N, 0); | ||
| return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Results); | ||
| } | ||
|
|
||
| // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -310,6 +310,187 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) # | |
| ret void | ||
| } | ||
|
|
||
| ; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first. | ||
|
|
||
| define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) { | ||
| ; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest: | ||
| ; CHECK-SVE: // %bb.0: // %entry | ||
| ; CHECK-SVE-NEXT: whilelo p1.b, x0, x1 | ||
| ; CHECK-SVE-NEXT: b.pl .LBB11_2 | ||
| ; CHECK-SVE-NEXT: // %bb.1: // %if.then | ||
| ; CHECK-SVE-NEXT: punpklo p0.h, p1.b | ||
| ; CHECK-SVE-NEXT: punpkhi p1.h, p1.b | ||
| ; CHECK-SVE-NEXT: b use | ||
| ; CHECK-SVE-NEXT: .LBB11_2: // %if.end | ||
| ; CHECK-SVE-NEXT: ret | ||
| ; | ||
| ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest: | ||
| ; CHECK-SVE2p1-SME2: // %bb.0: // %entry | ||
| ; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1 | ||
| ; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b | ||
| ; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b | ||
| ; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b | ||
| ; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2 | ||
| ; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then | ||
| ; CHECK-SVE2p1-SME2-NEXT: b use | ||
| ; CHECK-SVE2p1-SME2-NEXT: .LBB11_2: // %if.end | ||
| ; CHECK-SVE2p1-SME2-NEXT: ret | ||
| entry: | ||
| %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n) | ||
| %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0) | ||
| %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8) | ||
| %elt0 = extractelement <vscale x 16 x i1> %r, i32 0 | ||
| br i1 %elt0, label %if.then, label %if.end | ||
|
|
||
| if.then: | ||
| tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1) | ||
| br label %if.end | ||
|
|
||
| if.end: | ||
| ret void | ||
| } | ||
|
|
||
| ; Extra use of the get_active_lane_mask from an extractelement, which is | ||
| ; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1. | ||
|
|
||
| define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's worth adding a similar test for the |
||
| ; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts: | ||
| ; CHECK-SVE: // %bb.0: // %entry | ||
| ; CHECK-SVE-NEXT: whilelo p1.h, x0, x1 | ||
| ; CHECK-SVE-NEXT: b.pl .LBB12_2 | ||
| ; CHECK-SVE-NEXT: // %bb.1: // %if.then | ||
| ; CHECK-SVE-NEXT: punpklo p0.h, p1.b | ||
| ; CHECK-SVE-NEXT: punpkhi p1.h, p1.b | ||
| ; CHECK-SVE-NEXT: b use | ||
| ; CHECK-SVE-NEXT: .LBB12_2: // %if.end | ||
| ; CHECK-SVE-NEXT: ret | ||
| ; | ||
| ; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts: | ||
| ; CHECK-SVE2p1-SME2: // %bb.0: // %entry | ||
| ; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1 | ||
| ; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h | ||
| ; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h | ||
| ; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b | ||
| ; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2 | ||
| ; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then | ||
| ; CHECK-SVE2p1-SME2-NEXT: b use | ||
| ; CHECK-SVE2p1-SME2-NEXT: .LBB12_2: // %if.end | ||
| ; CHECK-SVE2p1-SME2-NEXT: ret | ||
| entry: | ||
| %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n) | ||
| %v0 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 0) | ||
| %v1 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 4) | ||
| %elt0 = extractelement <vscale x 8 x i1> %r, i64 0 | ||
| br i1 %elt0, label %if.then, label %if.end | ||
|
|
||
| if.then: | ||
| tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1) | ||
| br label %if.end | ||
|
|
||
| if.end: | ||
| ret void | ||
| } | ||
|
|
||
| define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) { | ||
| ; CHECK-SVE-LABEL: test_4x4bit_mask_with_extracts_and_ptest: | ||
| ; CHECK-SVE: // %bb.0: // %entry | ||
| ; CHECK-SVE-NEXT: whilelo p0.b, x0, x1 | ||
| ; CHECK-SVE-NEXT: b.pl .LBB13_2 | ||
| ; CHECK-SVE-NEXT: // %bb.1: // %if.then | ||
| ; CHECK-SVE-NEXT: punpklo p1.h, p0.b | ||
| ; CHECK-SVE-NEXT: punpkhi p3.h, p0.b | ||
| ; CHECK-SVE-NEXT: punpklo p0.h, p1.b | ||
| ; CHECK-SVE-NEXT: punpkhi p1.h, p1.b | ||
| ; CHECK-SVE-NEXT: punpklo p2.h, p3.b | ||
| ; CHECK-SVE-NEXT: punpkhi p3.h, p3.b | ||
| ; CHECK-SVE-NEXT: b use | ||
| ; CHECK-SVE-NEXT: .LBB13_2: // %if.end | ||
| ; CHECK-SVE-NEXT: ret | ||
| ; | ||
| ; CHECK-SVE2p1-SME2-LABEL: test_4x4bit_mask_with_extracts_and_ptest: | ||
| ; CHECK-SVE2p1-SME2: // %bb.0: // %entry | ||
| ; CHECK-SVE2p1-SME2-NEXT: cnth x8 | ||
| ; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8 | ||
| ; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo | ||
| ; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1 | ||
| ; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1 | ||
| ; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h | ||
| ; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h | ||
| ; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b | ||
| ; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b | ||
| ; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b | ||
| ; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2 | ||
| ; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then | ||
| ; CHECK-SVE2p1-SME2-NEXT: b use | ||
| ; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end | ||
| ; CHECK-SVE2p1-SME2-NEXT: ret | ||
| entry: | ||
| %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n) | ||
| %v0 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0) | ||
| %v1 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 4) | ||
| %v2 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8) | ||
| %v3 = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 12) | ||
| %elt0 = extractelement <vscale x 16 x i1> %r, i32 0 | ||
| br i1 %elt0, label %if.then, label %if.end | ||
|
|
||
| if.then: | ||
| tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1, <vscale x 4 x i1> %v2, <vscale x 4 x i1> %v3) | ||
| br label %if.end | ||
|
|
||
| if.end: | ||
| ret void | ||
| } | ||
|
|
||
| define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) { | ||
| ; CHECK-SVE-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts: | ||
| ; CHECK-SVE: // %bb.0: // %entry | ||
| ; CHECK-SVE-NEXT: whilelo p0.h, x0, x1 | ||
| ; CHECK-SVE-NEXT: b.pl .LBB14_2 | ||
| ; CHECK-SVE-NEXT: // %bb.1: // %if.then | ||
| ; CHECK-SVE-NEXT: punpklo p1.h, p0.b | ||
| ; CHECK-SVE-NEXT: punpkhi p3.h, p0.b | ||
| ; CHECK-SVE-NEXT: punpklo p0.h, p1.b | ||
| ; CHECK-SVE-NEXT: punpkhi p1.h, p1.b | ||
| ; CHECK-SVE-NEXT: punpklo p2.h, p3.b | ||
| ; CHECK-SVE-NEXT: punpkhi p3.h, p3.b | ||
| ; CHECK-SVE-NEXT: b use | ||
| ; CHECK-SVE-NEXT: .LBB14_2: // %if.end | ||
| ; CHECK-SVE-NEXT: ret | ||
| ; | ||
| ; CHECK-SVE2p1-SME2-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts: | ||
| ; CHECK-SVE2p1-SME2: // %bb.0: // %entry | ||
| ; CHECK-SVE2p1-SME2-NEXT: cntw x8 | ||
| ; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8 | ||
| ; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo | ||
| ; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1 | ||
| ; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1 | ||
| ; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s | ||
| ; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s | ||
| ; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h | ||
| ; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h | ||
| ; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b | ||
| ; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2 | ||
| ; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then | ||
| ; CHECK-SVE2p1-SME2-NEXT: b use | ||
| ; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end | ||
| ; CHECK-SVE2p1-SME2-NEXT: ret | ||
| entry: | ||
| %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i32(i64 %i, i64 %n) | ||
| %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0) | ||
| %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2) | ||
| %v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4) | ||
| %v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6) | ||
| %elt0 = extractelement <vscale x 8 x i1> %r, i32 0 | ||
| br i1 %elt0, label %if.then, label %if.end | ||
|
|
||
| if.then: | ||
| tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v3) | ||
| br label %if.end | ||
|
|
||
| if.end: | ||
| ret void | ||
| } | ||
|
|
||
| declare void @use(...) | ||
|
|
||
| attributes #0 = { nounwind } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: The `!!` looks a little odd. Is it possible to just use `DCI.isBeforeLegalizeOps()`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was a mistake; it should be `!DCI.isBeforeLegalizeOps()`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is doing what you intend. Given that `!DCI.isBeforeLegalizeOps()` means AfterLegalizeVectorOps, you've effectively written `if` "before type legalisation" and "after vector ops legalisation", which is always going to be false because they are at opposite ends of selection.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've now removed this altogether, after finding that a concat with 4+ inputs will be split after this combine without changing the stage at which it applies (although I've also made changes to create multiple concat_vectors in the latest commit).