Merged
4 changes: 4 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -840,6 +840,10 @@ def vector_insert_subvec : SDNode<"ISD::INSERT_SUBVECTOR",
def extract_subvector : SDNode<"ISD::EXTRACT_SUBVECTOR", SDTSubVecExtract, []>;
def insert_subvector : SDNode<"ISD::INSERT_SUBVECTOR", SDTSubVecInsert, []>;

def find_last_active
: SDNode<"ISD::VECTOR_FIND_LAST_ACTIVE",
SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<1>]>, []>;

// Nodes for intrinsics, you should use the intrinsic itself and let tblgen use
// these internally. Don't reference these directly.
def intrinsic_void : SDNode<"ISD::INTRINSIC_VOID",
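For context: when a target does not handle this node, SelectionDAG computes the index with a masked step vector and an unsigned-max reduction, the index/sel/umaxv sequence visible in the old CHECK lines of the test diff below. A minimal C++ sketch of that idea, assuming standard SelectionDAG helpers (illustrative only, not code from this patch):

static SDValue expandFindLastActiveSketch(SDNode *N, SelectionDAG &DAG) {
  SDLoc DL(N);
  SDValue Mask = N->getOperand(0);
  EVT ResVT = N->getValueType(0);
  // One lane number per mask lane: <0, 1, 2, ...>.
  EVT IdxVT = Mask.getValueType().changeVectorElementType(ResVT);
  SDValue Steps = DAG.getStepVector(DL, IdxVT);
  // Keep the lane number where the mask is set, else 0.
  SDValue Active =
      DAG.getSelect(DL, IdxVT, Mask, Steps, DAG.getConstant(0, DL, IdxVT));
  // The largest surviving lane number is the last active lane.
  return DAG.getNode(ISD::VECREDUCE_UMAX, DL, ResVT, Active);
}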
68 changes: 68 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1452,6 +1452,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
}
for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1})
setOperationAction(ISD::VECTOR_FIND_LAST_ACTIVE, VT, Legal);
}

if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -19730,6 +19732,33 @@ performLastTrueTestVectorCombine(SDNode *N,
return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::LAST_ACTIVE);
}

static SDValue
performExtractLastActiveCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
SelectionDAG &DAG = DCI.DAG;
SDValue Vec = N->getOperand(0);
SDValue Idx = N->getOperand(1);

if (DCI.isBeforeLegalize() || Idx.getOpcode() != ISD::VECTOR_FIND_LAST_ACTIVE)
return SDValue();

// Only legal for 8, 16, 32, and 64 bit element types.
EVT EltVT = Vec.getValueType().getVectorElementType();
if (!is_contained(
ArrayRef({MVT::i8, MVT::i16, MVT::i32, MVT::i64, MVT::f32, MVT::f64}),
Collaborator: What about f16 and bf16?
EltVT.getSimpleVT().SimpleTy))
return SDValue();

SDValue Mask = Idx.getOperand(0);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (!TLI.isOperationLegal(ISD::VECTOR_FIND_LAST_ACTIVE, Mask.getValueType()))
return SDValue();

return DAG.getNode(AArch64ISD::LASTB, SDLoc(N), N->getValueType(0), Mask,
Vec);
}
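In DAG terms, the combine above rewrites the node pair produced by the generic lowering of the intrinsic into a single target node. Roughly (sketch; dump format is approximate):

t0: i64 = VECTOR_FIND_LAST_ACTIVE Mask
t1: i8 = extract_vector_elt Data, t0
-->
t1: i8 = AArch64ISD::LASTB Mask, Data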

static SDValue
performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -19738,6 +19767,8 @@ performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
return Res;
if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
return Res;
if (SDValue Res = performExtractLastActiveCombine(N, DCI, Subtarget))
return Res;

SelectionDAG &DAG = DCI.DAG;
SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
@@ -24852,6 +24883,39 @@ static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) {
}
}

static SDValue foldCSELofLASTB(SDNode *Op, SelectionDAG &DAG) {
AArch64CC::CondCode OpCC =
static_cast<AArch64CC::CondCode>(Op->getConstantOperandVal(2));

if (OpCC != AArch64CC::NE)
return SDValue();

SDValue PTest = Op->getOperand(3);
if (PTest.getOpcode() != AArch64ISD::PTEST_ANY)
return SDValue();

SDValue TruePred = PTest.getOperand(0);
SDValue AnyPred = PTest.getOperand(1);

if (TruePred.getOpcode() == AArch64ISD::REINTERPRET_CAST)
TruePred = TruePred.getOperand(0);

if (AnyPred.getOpcode() == AArch64ISD::REINTERPRET_CAST)
AnyPred = AnyPred.getOperand(0);

if (TruePred != AnyPred && TruePred.getOpcode() != AArch64ISD::PTRUE)
return SDValue();

SDValue LastB = Op->getOperand(0);
SDValue Default = Op->getOperand(1);

if (LastB.getOpcode() != AArch64ISD::LASTB || LastB.getOperand(0) != AnyPred)
return SDValue();

return DAG.getNode(AArch64ISD::CLASTB_N, SDLoc(Op), Op->getValueType(0),
AnyPred, Default, LastB.getOperand(1));
}
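For a non-poison passthru, the generic path guards the extracted element with a CSEL on an "any lane active" test; the fold above collapses that idiom into SVE's conditional-last instruction. A sketch of the rewrite, with operand order as in the code above (approximate):

CSEL (LASTB Pg, Zdata), Default, NE, (PTEST_ANY Ptrue, Pg)
-->
CLASTB_N Pg, Default, Zdata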

// Optimize CSEL instructions
static SDValue performCSELCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
@@ -24897,6 +24961,10 @@ static SDValue performCSELCombine(SDNode *N,
}
}

// CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z
if (SDValue CondLast = foldCSELofLASTB(N, DAG))
return CondLast;

return performCONDCombine(N, DCI, DAG, 2, 3);
}

14 changes: 14 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3379,6 +3379,20 @@ let Predicates = [HasSVE_or_SME] in {
def : Pat<(i64 (vector_extract nxv2i64:$vec, VectorIndexD:$index)),
(UMOVvi64 (v2i64 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexD:$index)>;

// Find index of last active lane. This is a fallback in case we miss the
// opportunity to fold into a lastb or clastb directly.
Comment on lines +3382 to +3383

Member: Are these fallback patterns tested in the final patch?

Collaborator: I agree it would be good to have some tests for these.

Collaborator (Author): Sadly it's pretty difficult to do this once the combines have been added. I don't see a global switch to disable combining, just the target-independent combines. We do check the optimization level in a few places in AArch64ISelLowering, but mostly for TLI methods for IR-level decisions. Deliberately turning off (c)lastb pattern matching at O0 feels odd, and adding a new switch just for this feature also feels excessive.

I could potentially add a GlobalISel-based test, though I'm not sure how much code that requires. We've added a few new ISD nodes recently, and none have added support in GlobalISel.

I guess this is mostly due to it being hard to just create a SelectionDAG without IR and run selection over it.

def : Pat<(i64(find_last_active nxv16i1:$P1)),
(INSERT_SUBREG(IMPLICIT_DEF), (LASTB_RPZ_B $P1, (INDEX_II_B 0, 1)),
sub_32)>;
def : Pat<(i64(find_last_active nxv8i1:$P1)),
(INSERT_SUBREG(IMPLICIT_DEF), (LASTB_RPZ_H $P1, (INDEX_II_H 0, 1)),
sub_32)>;
def : Pat<(i64(find_last_active nxv4i1:$P1)),
(INSERT_SUBREG(IMPLICIT_DEF), (LASTB_RPZ_S $P1, (INDEX_II_S 0, 1)),
sub_32)>;
def : Pat<(i64(find_last_active nxv2i1:$P1)),
          (LASTB_RPZ_D $P1, (INDEX_II_D 0, 1))>;
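When neither combine fires, these patterns still select a compact sequence: INDEX materializes the lane numbers, LASTB over the mask picks out the number of the last active lane, and INSERT_SUBREG models the implicit zero-extension of the 32-bit result to i64 for the sub-64-bit element cases. Hypothetical output for the nxv16i1 pattern (register choices are illustrative):

index z0.b, #0, #1 // z0.b = <0, 1, 2, ..., vl-1>
lastb w8, p0, z0.b // w8 = lane number of the last active lane
// writes to a W register implicitly zero-extend into the X register, giving the i64 result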

// Move element from the bottom 128-bits of a scalable vector to a single-element vector.
// Alternative case where insertelement is just scalar_to_vector rather than vector_insert.
def : Pat<(v1f64 (scalar_to_vector
95 changes: 23 additions & 72 deletions llvm/test/CodeGen/AArch64/vector-extract-last-active.ll
@@ -293,17 +293,7 @@ define double @extract_last_double(<2 x double> %data, <2 x i64> %mask, double %
define i8 @extract_last_i8_scalable(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask, i8 %passthru) #0 {
; CHECK-LABEL: extract_last_i8_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: index z1.b, #0, #1
; CHECK-NEXT: mov z2.b, #0 // =0x0
; CHECK-NEXT: ptrue p1.b
; CHECK-NEXT: sel z1.b, p0, z1.b, z2.b
; CHECK-NEXT: umaxv b1, p1, z1.b
; CHECK-NEXT: fmov w8, s1
; CHECK-NEXT: and x8, x8, #0xff
; CHECK-NEXT: whilels p1.b, xzr, x8
; CHECK-NEXT: ptest p0, p0.b
; CHECK-NEXT: lastb w8, p1, z0.b
; CHECK-NEXT: csel w0, w8, w0, ne
; CHECK-NEXT: clastb w0, p0, w0, z0.b
; CHECK-NEXT: ret
%res = call i8 @llvm.experimental.vector.extract.last.active.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask, i8 %passthru)
ret i8 %res
@@ -312,17 +302,7 @@ define i16 @extract_last_i16_scalable(<vscale x 8 x i16> %data, <vscale x 8 x i1
define i16 @extract_last_i16_scalable(<vscale x 8 x i16> %data, <vscale x 8 x i1> %mask, i16 %passthru) #0 {
; CHECK-LABEL: extract_last_i16_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: index z1.h, #0, #1
; CHECK-NEXT: mov z2.h, #0 // =0x0
; CHECK-NEXT: ptrue p1.h
; CHECK-NEXT: sel z1.h, p0, z1.h, z2.h
; CHECK-NEXT: umaxv h1, p1, z1.h
; CHECK-NEXT: fmov w8, s1
; CHECK-NEXT: and x8, x8, #0xffff
; CHECK-NEXT: whilels p2.h, xzr, x8
; CHECK-NEXT: ptest p1, p0.b
; CHECK-NEXT: lastb w8, p2, z0.h
; CHECK-NEXT: csel w0, w8, w0, ne
; CHECK-NEXT: clastb w0, p0, w0, z0.h
; CHECK-NEXT: ret
%res = call i16 @llvm.experimental.vector.extract.last.active.nxv8i16(<vscale x 8 x i16> %data, <vscale x 8 x i1> %mask, i16 %passthru)
ret i16 %res
@@ -331,17 +311,7 @@ define i32 @extract_last_i32_scalable(<vscale x 4 x i32> %data, <vscale x 4 x i1
define i32 @extract_last_i32_scalable(<vscale x 4 x i32> %data, <vscale x 4 x i1> %mask, i32 %passthru) #0 {
; CHECK-LABEL: extract_last_i32_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: index z1.s, #0, #1
; CHECK-NEXT: mov z2.s, #0 // =0x0
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: sel z1.s, p0, z1.s, z2.s
; CHECK-NEXT: umaxv s1, p1, z1.s
; CHECK-NEXT: fmov w8, s1
; CHECK-NEXT: mov w8, w8
; CHECK-NEXT: whilels p2.s, xzr, x8
; CHECK-NEXT: ptest p1, p0.b
; CHECK-NEXT: lastb w8, p2, z0.s
; CHECK-NEXT: csel w0, w8, w0, ne
; CHECK-NEXT: clastb w0, p0, w0, z0.s
; CHECK-NEXT: ret
%res = call i32 @llvm.experimental.vector.extract.last.active.nxv4i32(<vscale x 4 x i32> %data, <vscale x 4 x i1> %mask, i32 %passthru)
ret i32 %res
@@ -350,16 +320,7 @@ define i64 @extract_last_i64_scalable(<vscale x 2 x i64> %data, <vscale x 2 x i1
define i64 @extract_last_i64_scalable(<vscale x 2 x i64> %data, <vscale x 2 x i1> %mask, i64 %passthru) #0 {
; CHECK-LABEL: extract_last_i64_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: index z1.d, #0, #1
; CHECK-NEXT: mov z2.d, #0 // =0x0
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: sel z1.d, p0, z1.d, z2.d
; CHECK-NEXT: umaxv d1, p1, z1.d
; CHECK-NEXT: fmov x8, d1
; CHECK-NEXT: whilels p2.d, xzr, x8
; CHECK-NEXT: ptest p1, p0.b
; CHECK-NEXT: lastb x8, p2, z0.d
; CHECK-NEXT: csel x0, x8, x0, ne
; CHECK-NEXT: clastb x0, p0, x0, z0.d
; CHECK-NEXT: ret
%res = call i64 @llvm.experimental.vector.extract.last.active.nxv2i64(<vscale x 2 x i64> %data, <vscale x 2 x i1> %mask, i64 %passthru)
ret i64 %res
@@ -368,17 +329,8 @@ define float @extract_last_float_scalable(<vscale x 4 x float> %data, <vscale x
define float @extract_last_float_scalable(<vscale x 4 x float> %data, <vscale x 4 x i1> %mask, float %passthru) #0 {
; CHECK-LABEL: extract_last_float_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: index z2.s, #0, #1
; CHECK-NEXT: mov z3.s, #0 // =0x0
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: sel z2.s, p0, z2.s, z3.s
; CHECK-NEXT: umaxv s2, p1, z2.s
; CHECK-NEXT: fmov w8, s2
; CHECK-NEXT: mov w8, w8
; CHECK-NEXT: whilels p2.s, xzr, x8
; CHECK-NEXT: ptest p1, p0.b
; CHECK-NEXT: lastb s0, p2, z0.s
; CHECK-NEXT: fcsel s0, s0, s1, ne
; CHECK-NEXT: clastb s1, p0, s1, z0.s
; CHECK-NEXT: fmov s0, s1
; CHECK-NEXT: ret
%res = call float @llvm.experimental.vector.extract.last.active.nxv4f32(<vscale x 4 x float> %data, <vscale x 4 x i1> %mask, float %passthru)
ret float %res
@@ -387,16 +339,8 @@ define double @extract_last_double_scalable(<vscale x 2 x double> %data, <vscale
define double @extract_last_double_scalable(<vscale x 2 x double> %data, <vscale x 2 x i1> %mask, double %passthru) #0 {
; CHECK-LABEL: extract_last_double_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: index z2.d, #0, #1
; CHECK-NEXT: mov z3.d, #0 // =0x0
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: sel z2.d, p0, z2.d, z3.d
; CHECK-NEXT: umaxv d2, p1, z2.d
; CHECK-NEXT: fmov x8, d2
; CHECK-NEXT: whilels p2.d, xzr, x8
; CHECK-NEXT: ptest p1, p0.b
; CHECK-NEXT: lastb d0, p2, z0.d
; CHECK-NEXT: fcsel d0, d0, d1, ne
; CHECK-NEXT: clastb d1, p0, d1, z0.d
; CHECK-NEXT: fmov d0, d1
; CHECK-NEXT: ret
%res = call double @llvm.experimental.vector.extract.last.active.nxv2f64(<vscale x 2 x double> %data, <vscale x 2 x i1> %mask, double %passthru)
ret double %res
@@ -406,20 +350,26 @@ define i8 @extract_last_i8_scalable_poison_passthru(<vscale x 16 x i8> %data, <vscale
define i8 @extract_last_i8_scalable_poison_passthru(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask) #0 {
; CHECK-LABEL: extract_last_i8_scalable_poison_passthru:
; CHECK: // %bb.0:
; CHECK-NEXT: index z1.b, #0, #1
; CHECK-NEXT: mov z2.b, #0 // =0x0
; CHECK-NEXT: sel z1.b, p0, z1.b, z2.b
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: umaxv b1, p0, z1.b
; CHECK-NEXT: fmov w8, s1
; CHECK-NEXT: and x8, x8, #0xff
; CHECK-NEXT: whilels p0.b, xzr, x8
; CHECK-NEXT: lastb w0, p0, z0.b
; CHECK-NEXT: ret
%res = call i8 @llvm.experimental.vector.extract.last.active.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask, i8 poison)
ret i8 %res
}

;; (c)lastb doesn't exist for predicate types; check we get functional codegen
define i1 @extract_last_i1_scalable(<vscale x 16 x i1> %data, <vscale x 16 x i1> %mask) #0 {
; CHECK-LABEL: extract_last_i1_scalable:
; CHECK: // %bb.0:
; CHECK-NEXT: mov z0.b, p0/z, #1 // =0x1
; CHECK-NEXT: ptest p1, p1.b
; CHECK-NEXT: cset w9, ne
; CHECK-NEXT: lastb w8, p1, z0.b
; CHECK-NEXT: and w0, w9, w8
; CHECK-NEXT: ret
%res = call i1 @llvm.experimental.vector.extract.last.active.nxv16i1(<vscale x 16 x i1> %data, <vscale x 16 x i1> %mask, i1 false)
ret i1 %res
}

declare i8 @llvm.experimental.vector.extract.last.active.v16i8(<16 x i8>, <16 x i1>, i8)
declare i16 @llvm.experimental.vector.extract.last.active.v8i16(<8 x i16>, <8 x i1>, i16)
declare i32 @llvm.experimental.vector.extract.last.active.v4i32(<4 x i32>, <4 x i1>, i32)
@@ -432,5 +382,6 @@ declare i32 @llvm.experimental.vector.extract.last.active.nxv4i32(<vscale x 4 x
declare i64 @llvm.experimental.vector.extract.last.active.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i1>, i64)
declare float @llvm.experimental.vector.extract.last.active.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, float)
declare double @llvm.experimental.vector.extract.last.active.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, double)
declare i1 @llvm.experimental.vector.extract.last.active.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, i1)

attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }