Skip to content

Commit a3167b5

Browse files
[AArch64] Combine PTEST_FIRST(PTRUE, CONCAT(A, B)) -> PTEST_FIRST(PTRUE, A)
When the input to a ptest_first is a vector concat and the mask is all active, performPTestFirstCombine returns a ptest_first that uses only the first operand of the concat, looking through any reinterpret casts added by getPTest. This allows optimizePTestInstr to later remove the ptest when that first operand is produced by a flag-setting instruction such as whilelo.
1 parent cf50bbf commit a3167b5

File tree

2 files changed

+40
-20
lines changed

2 files changed

+40
-20
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20370,7 +20370,7 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
2037020370
}
2037120371

2037220372
static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
20373-
AArch64CC::CondCode Cond);
20373+
AArch64CC::CondCode Cond, bool EmitCSel = true);
2037420374

2037520375
static bool isPredicateCCSettingOp(SDValue N) {
2037620376
if ((N.getOpcode() == ISD::SETCC) ||
@@ -20495,6 +20495,7 @@ static SDValue
2049520495
performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
2049620496
const AArch64Subtarget *Subtarget) {
2049720497
assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
20498+
2049820499
if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget))
2049920500
return Res;
2050020501
if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
@@ -22535,7 +22536,7 @@ static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC,
2253522536
}
2253622537

2253722538
static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
22538-
AArch64CC::CondCode Cond) {
22539+
AArch64CC::CondCode Cond, bool EmitCSel) {
2253922540
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2254022541

2254122542
SDLoc DL(Op);
@@ -22568,6 +22569,8 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
2256822569

2256922570
// Set condition code (CC) flags.
2257022571
SDValue Test = DAG.getNode(PTest, DL, MVT::i32, Pg, Op);
22572+
if (!EmitCSel)
22573+
return Test;
2257122574

2257222575
// Convert CC to integer based on requested condition.
2257322576
// NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
@@ -27519,6 +27522,37 @@ static SDValue performMULLCombine(SDNode *N,
2751927522
return SDValue();
2752027523
}
2752127524

27525+
/// Combine PTEST_FIRST(PTRUE, CONCAT_VECTORS(A, ...)) -> PTEST_FIRST(PTRUE, A).
///
/// With an all-active governing predicate, only the first operand of the
/// concat is tested, so the remaining operands are dropped from the test.
///
/// \param N   The AArch64ISD::PTEST_FIRST node; operand 0 is the governing
///            mask and operand 1 is the predicate being tested.
/// \param DCI Combiner state; the combine is skipped before legalization.
/// \param DAG The DAG in which replacement nodes are created.
/// \returns The flag-setting test node built by getPTest (no CSEL is emitted),
///          or an empty SDValue when the pattern does not match.
static SDValue performPTestFirstCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        SelectionDAG &DAG) {
  if (DCI.isBeforeLegalize())
    return SDValue();

  SDLoc DL(N);
  SDValue Mask = N->getOperand(0);
  SDValue Pred = N->getOperand(1);

  // Look through a single reinterpret cast on either operand so the checks
  // below see the underlying predicate values.
  if (Mask.getOpcode() == AArch64ISD::REINTERPRET_CAST)
    Mask = Mask.getOperand(0);

  if (Pred.getOpcode() == AArch64ISD::REINTERPRET_CAST)
    Pred = Pred.getOperand(0);

  // Query the type of the value itself (SDValue::getValueType honours the
  // result number) rather than result #0 of the defining node, which may be a
  // different result when the node produces several values.
  if (Pred.getValueType().getVectorElementType() != MVT::i1 ||
      !isAllActivePredicate(DAG, Mask))
    return SDValue();

  if (Pred.getOpcode() != ISD::CONCAT_VECTORS)
    return SDValue();

  // Test only the first concat operand, governed by an all-ones splat of that
  // operand's own (narrower) predicate type.
  SDValue FirstPart = Pred.getOperand(0);
  SDValue AllActive = DAG.getSplatVector(FirstPart.getValueType(), DL,
                                         DAG.getAllOnesConstant(DL, MVT::i64));

  // EmitCSel is false: the caller expects the raw test node, not a CSEL that
  // materialises the condition as an integer.
  return getPTest(DAG, N->getValueType(0), AllActive, FirstPart,
                  AArch64CC::FIRST_ACTIVE, /*EmitCSel=*/false);
}
27555+
2752227556
static SDValue
2752327557
performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
2752427558
SelectionDAG &DAG) {
@@ -27875,6 +27909,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2787527909
case AArch64ISD::UMULL:
2787627910
case AArch64ISD::PMULL:
2787727911
return performMULLCombine(N, DCI, DAG);
27912+
case AArch64ISD::PTEST_FIRST:
27913+
return performPTestFirstCombine(N, DCI, DAG);
2787827914
case ISD::INTRINSIC_VOID:
2787927915
case ISD::INTRINSIC_W_CHAIN:
2788027916
switch (N->getConstantOperandVal(1)) {

llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,6 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
327327
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
328328
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
329329
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
330-
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b
331-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b
332-
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
333330
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
334331
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
335332
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -368,9 +365,6 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
368365
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
369366
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
370367
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
371-
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h
372-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h
373-
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
374368
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
375369
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
376370
; CHECK-SVE2p1-SME2-NEXT: b use
@@ -413,14 +407,9 @@ define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
413407
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
414408
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
415409
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
416-
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
417-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h
418-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h
419-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b
420-
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b
421-
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
422410
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2
423411
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
412+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
424413
; CHECK-SVE2p1-SME2-NEXT: b use
425414
; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end
426415
; CHECK-SVE2p1-SME2-NEXT: ret
@@ -463,14 +452,9 @@ define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
463452
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
464453
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
465454
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
466-
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
467-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s
468-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s
469-
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h
470-
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h
471-
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
472455
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2
473456
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
457+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
474458
; CHECK-SVE2p1-SME2-NEXT: b use
475459
; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end
476460
; CHECK-SVE2p1-SME2-NEXT: ret

0 commit comments

Comments
 (0)