Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20370,7 +20370,7 @@ static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
}

static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
AArch64CC::CondCode Cond);
AArch64CC::CondCode Cond, bool EmitCSel = true);

static bool isPredicateCCSettingOp(SDValue N) {
if ((N.getOpcode() == ISD::SETCC) ||
Expand Down Expand Up @@ -20495,6 +20495,7 @@ static SDValue
performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);

if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget))
return Res;
if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
Expand Down Expand Up @@ -22535,7 +22536,7 @@ static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC,
}

static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
AArch64CC::CondCode Cond) {
AArch64CC::CondCode Cond, bool EmitCSel) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();

SDLoc DL(Op);
Expand Down Expand Up @@ -22568,6 +22569,8 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,

// Set condition code (CC) flags.
SDValue Test = DAG.getNode(PTest, DL, MVT::i32, Pg, Op);
if (!EmitCSel)
return Test;

// Convert CC to integer based on requested condition.
// NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
Expand Down Expand Up @@ -27519,6 +27522,37 @@ static SDValue performMULLCombine(SDNode *N,
return SDValue();
}

static SDValue performPTestFirstCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
if (DCI.isBeforeLegalize())
return SDValue();

SDLoc DL(N);
auto Mask = N->getOperand(0);
auto Pred = N->getOperand(1);

if (Mask->getOpcode() == AArch64ISD::REINTERPRET_CAST)
Mask = Mask->getOperand(0);

if (Pred->getOpcode() == AArch64ISD::REINTERPRET_CAST)
Pred = Pred->getOperand(0);

if (Pred->getValueType(0).getVectorElementType() != MVT::i1 ||
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need the element type check because AArch64ISD::REINTERPRET_CAST is only allowed to change the element count.

!isAllActivePredicate(DAG, Mask))
return SDValue();

if (Pred->getOpcode() == ISD::CONCAT_VECTORS) {
Pred = Pred->getOperand(0);
SDValue Mask = DAG.getSplatVector(Pred->getValueType(0), DL,
DAG.getAllOnesConstant(DL, MVT::i64));
Copy link
Collaborator

@paulwalker-arm paulwalker-arm Sep 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You've already proven the existing Mask does what you need, so it should be reusable. Perhaps it's worth adding an isLane0KnownOne helper function; then you don't need to strip the reinterpret cast from Mask and can reuse it directly.

return getPTest(DAG, N->getValueType(0), Mask, Pred,
AArch64CC::FIRST_ACTIVE, /* EmitCSel */ false);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt you're using much in getPTest to warrant the EmitCSel change. Having already proven the DAG is testing the first lane of the predicate[1] there's no complexity related to "hidden lanes" and so this can be simplified to:

Pred = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pred);
Mask = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Mask);
return DAG.getNode(AArch64ISD::PTEST_FIRST, DL, N->getValueType(0), Mask, Pred);

[1] It is worth adding commentary detailing the "we know we are testing the first lane" requirement, because that is key to the combine.

}

return SDValue();
}

static SDValue
performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
Expand Down Expand Up @@ -27875,6 +27909,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case AArch64ISD::UMULL:
case AArch64ISD::PMULL:
return performMULLCombine(N, DCI, DAG);
case AArch64ISD::PTEST_FIRST:
return performPTestFirstCombine(N, DCI, DAG);
case ISD::INTRINSIC_VOID:
case ISD::INTRINSIC_W_CHAIN:
switch (N->getConstantOperandVal(1)) {
Expand Down
20 changes: 2 additions & 18 deletions llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,6 @@ define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.b
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.b, p0.b, p1.b
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
Expand Down Expand Up @@ -368,9 +365,6 @@ define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
; CHECK-SVE2p1-SME2: // %bb.0: // %entry
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: ptrue p2.h
; CHECK-SVE2p1-SME2-NEXT: uzp1 p3.h, p0.h, p1.h
; CHECK-SVE2p1-SME2-NEXT: ptest p2, p3.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: b use
Expand Down Expand Up @@ -413,14 +407,9 @@ define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p0.h, p1.h
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p2.h, p3.h
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p4.b, p5.b
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.b
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB13_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
Expand Down Expand Up @@ -463,14 +452,9 @@ define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p0.s, p1.s
; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p2.s, p3.s
; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p4.h, p5.h
; CHECK-SVE2p1-SME2-NEXT: ptrue p5.h
; CHECK-SVE2p1-SME2-NEXT: ptest p5, p4.b
; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB14_2
; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
; CHECK-SVE2p1-SME2-NEXT: b use
; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end
; CHECK-SVE2p1-SME2-NEXT: ret
Expand Down