Skip to content

Commit 942cd73

Browse files
kmclaughlin-arm authored and mahesh-attarde committed
[AArch64] Combine PTEST_FIRST(PTRUE, CONCAT(A, B)) -> PTEST_FIRST(PTRUE, A) (llvm#161384)
When the input to ptest_first is a vector concat and the mask is all-active, performPTestFirstCombine returns a ptest_first that uses the first operand of the concat, looking through any reinterpret casts. This allows optimizePTestInstr to later remove the ptest when the first operand is a flag-setting instruction such as whilelo.
1 parent b60e9c8 commit 942cd73

File tree

3 files changed

+52
-18
lines changed

3 files changed

+52
-18
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27234,6 +27234,21 @@ static bool isLanes1toNKnownZero(SDValue Op) {
2723427234
}
2723527235
}
2723627236

27237+
// Return true if the vector operation can guarantee that the first lane of its
27238+
// result is active.
27239+
static bool isLane0KnownActive(SDValue Op) {
27240+
switch (Op.getOpcode()) {
27241+
default:
27242+
return false;
27243+
case AArch64ISD::REINTERPRET_CAST:
27244+
return isLane0KnownActive(Op->getOperand(0));
27245+
case ISD::SPLAT_VECTOR:
27246+
return isOneConstant(Op.getOperand(0));
27247+
case AArch64ISD::PTRUE:
27248+
return Op.getConstantOperandVal(0) == AArch64SVEPredPattern::all;
27249+
};
27250+
}
27251+
2723727252
static SDValue removeRedundantInsertVectorElt(SDNode *N) {
2723827253
assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT && "Unexpected node!");
2723927254
SDValue InsertVec = N->getOperand(0);
@@ -27519,6 +27534,32 @@ static SDValue performMULLCombine(SDNode *N,
2751927534
return SDValue();
2752027535
}
2752127536

27537+
// Combine PTEST_FIRST(Mask, CONCAT_VECTORS(A, ...)) into
// PTEST_FIRST(Mask, REINTERPRET_CAST<nxv16i1>(A)) when lane 0 of Mask is known
// to be active. A single REINTERPRET_CAST wrapped around the concat is looked
// through first.
//
// Returns the replacement node, or an empty SDValue when the combine is
// requested before legalization, the mask's first lane is not known active,
// or the tested predicate is not a CONCAT_VECTORS.
static SDValue performPTestFirstCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        SelectionDAG &DAG) {
  if (DCI.isBeforeLegalize())
    return SDValue();

  SDValue Mask = N->getOperand(0);
  if (!isLane0KnownActive(Mask))
    return SDValue();

  // Strip at most one reinterpret cast from the predicate under test.
  SDValue Tested = N->getOperand(1);
  if (Tested.getOpcode() == AArch64ISD::REINTERPRET_CAST)
    Tested = Tested.getOperand(0);

  if (Tested.getOpcode() != ISD::CONCAT_VECTORS)
    return SDValue();

  // Test only the first concat operand, cast to the nxv16i1 predicate type.
  SDLoc DL(N);
  SDValue FirstPart = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
                                  MVT::nxv16i1, Tested.getOperand(0));
  return DAG.getNode(AArch64ISD::PTEST_FIRST, DL, N->getValueType(0), Mask,
                     FirstPart);
}
27562+
2752227563
static SDValue
2752327564
performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
2752427565
SelectionDAG &DAG) {
@@ -27875,6 +27916,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2787527916
case AArch64ISD::UMULL:
2787627917
case AArch64ISD::PMULL:
2787727918
return performMULLCombine(N, DCI, DAG);
27919+
case AArch64ISD::PTEST_FIRST:
27920+
return performPTestFirstCombine(N, DCI, DAG);
2787827921
case ISD::INTRINSIC_VOID:
2787927922
case ISD::INTRINSIC_W_CHAIN:
2788027923
switch (N->getConstantOperandVal(1)) {

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1503,6 +1503,13 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
15031503
getElementSizeForOpcode(PredOpcode))
15041504
return PredOpcode;
15051505

1506+
// For PTEST_FIRST(PTRUE_ALL, WHILE), the PTEST_FIRST is redundant since
1507+
// WHILEcc performs an implicit PTEST with an all active mask, setting
1508+
// the N flag as the PTEST_FIRST would.
1509+
if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST &&
1510+
isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31)
1511+
return PredOpcode;
1512+
15061513
return {};
15071514
}
15081515

llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,6 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
327327
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
328328
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
329329
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
330-
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b
331-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b
332-
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
333330
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
334331
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
335332
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -368,9 +365,6 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
368365
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
369366
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
370367
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
371-
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h
372-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h
373-
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
374368
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
375369
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
376370
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -413,14 +407,9 @@ define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
413407
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
414408
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
415409
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
416-
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
417-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h
418-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h
419-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b
420-
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b
421-
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
422410
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2
423411
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
412+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
424413
; CHECK-SVE2p1-SME2-NEXT: b use
425414
; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end
426415
; CHECK-SVE2p1-SME2-NEXT: ret
@@ -463,14 +452,9 @@ define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
463452
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
464453
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
465454
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
466-
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
467-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s
468-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s
469-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h
470-
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h
471-
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
472455
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2
473456
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
457+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
474458
; CHECK-SVE2p1-SME2-NEXT: b use
475459
; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end
476460
; CHECK-SVE2p1-SME2-NEXT: ret

0 commit comments

Comments
 (0)