@@ -2059,6 +2059,19 @@ bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
          VT != MVT::v4i1 && VT != MVT::v2i1;
 }

+bool AArch64TargetLowering::shouldExpandVectorMatch(EVT VT,
+                                                    unsigned SearchSize) const {
+  // MATCH is SVE2 and only available in non-streaming mode.
+  if (!Subtarget->hasSVE2() || !Subtarget->isSVEAvailable())
+    return true;
+  // Furthermore, we can only use it for 8-bit or 16-bit elements.
+  if (VT == MVT::nxv8i16 || VT == MVT::v8i16)
+    return SearchSize != 8;
+  if (VT == MVT::nxv16i8 || VT == MVT::v16i8 || VT == MVT::v8i8)
+    return SearchSize != 8 && SearchSize != 16;
+  return true;
+}
+
 void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
   assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");

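
As a reading aid (not part of the patch): the hook returns true when the generic expansion must be used instead of SVE2 MATCH, and the accepted (VT, SearchSize) pairs are exactly those where the needle fills a 128-bit segment or half of one. A minimal standalone restatement of that decision table, with hypothetical names rather than LLVM's API:

// Illustrative sketch only; mustExpandMatch is a made-up helper mirroring
// shouldExpandVectorMatch above. "Expand" means SVE2 MATCH cannot be used.
#include <cassert>

static bool mustExpandMatch(bool HasSVE2, bool NonStreamingSVE,
                            unsigned EltBits, unsigned SearchSize) {
  if (!HasSVE2 || !NonStreamingSVE)
    return true; // MATCH is SVE2-only and unavailable in streaming mode.
  if (EltBits == 16)
    return SearchSize != 8; // 8 x i16 fills one 128-bit segment.
  if (EltBits == 8)
    return SearchSize != 8 && SearchSize != 16; // Half or full segment of i8.
  return true; // 32/64-bit elements are never handled by MATCH.
}

int main() {
  assert(!mustExpandMatch(true, true, 8, 16)); // 16 x i8 needle: use MATCH.
  assert(mustExpandMatch(true, true, 16, 16)); // 16 x i16 spans two segments.
  assert(mustExpandMatch(false, true, 8, 16)); // No SVE2: always expand.
}
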
@@ -5780,6 +5793,72 @@ SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
                         DAG.getTargetConstant(ImmAddend, DL, MVT::i32)});
 }

+SDValue LowerVectorMatch(SDValue Op, SelectionDAG &DAG) {
+  SDLoc dl(Op);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);
+
+  auto Op1 = Op.getOperand(1);
+  auto Op2 = Op.getOperand(2);
+  auto Mask = Op.getOperand(3);
+
+  EVT Op1VT = Op1.getValueType();
+  EVT Op2VT = Op2.getValueType();
+  EVT ResVT = Op.getValueType();
+
+  assert((Op1VT.getVectorElementType() == MVT::i8 ||
+          Op1VT.getVectorElementType() == MVT::i16) &&
+         "Expected 8-bit or 16-bit characters.");
+
+  // Scalable vector type used to wrap operands.
+  // A single container is enough for both operands because ultimately the
+  // operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
+  EVT OpContainerVT = Op1VT.isScalableVector()
+                          ? Op1VT
+                          : getContainerForFixedLengthVector(DAG, Op1VT);
+
+  if (Op2VT.is128BitVector()) {
+    // If Op2 is a full 128-bit vector, wrap it trivially in a scalable vector.
+    Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
+    // Further, if the result is scalable, broadcast Op2 to a full SVE register.
+    if (ResVT.isScalableVector())
+      Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
+                        DAG.getTargetConstant(0, dl, MVT::i64));
+  } else {
+    // If Op2 is not a full 128-bit vector, we always need to broadcast it.
+    unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
+    MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
+    EVT Op2PromotedVT = getPackedSVEVectorVT(Op2IntVT);
+    Op2 = DAG.getBitcast(MVT::getVectorVT(Op2IntVT, 1), Op2);
+    Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT, Op2,
+                      DAG.getConstant(0, dl, MVT::i64));
+    Op2 = DAG.getSplatVector(Op2PromotedVT, dl, Op2);
+    Op2 = DAG.getBitcast(OpContainerVT, Op2);
+  }
+
+  // If the result is scalable, we just need to carry out the MATCH.
+  if (ResVT.isScalableVector())
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1, Op2);
+
+  // If the result is fixed, we can still use MATCH but we need to wrap the
+  // first operand and the mask in scalable vectors before doing so.
+
+  // Wrap the operands.
+  Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
+  Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, Op1VT, Mask);
+  Mask = convertFixedMaskToScalableVector(Mask, DAG);
+
+  // Carry out the match.
+  SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Mask.getValueType(),
+                              ID, Mask, Op1, Op2);
+
+  // Extract and promote the match result (nxv16i1/nxv8i1) to ResVT
+  // (v16i8/v8i8).
+  Match = DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match);
+  Match = convertFromScalableVector(DAG, Op1VT, Match);
+  return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
+}
+
 SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
                                                    SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
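
A scalar model may help when reading the lowering above (illustrative only, not LLVM code). SVE2 MATCH tests each active lane of the first operand against that lane's 128-bit segment of the second operand; because the code broadcasts the (possibly splatted) needle into every segment, the net effect for this intrinsic reduces to plain set membership per active lane:

// Hypothetical scalar model of the node built by LowerVectorMatch. After the
// DUPLANE128/splat broadcasts, every 128-bit segment of Op2 holds the same
// needle, so each active lane is a membership test; inactive lanes are false.
#include <cstdint>
#include <vector>

static std::vector<bool> vectorMatchModel(const std::vector<uint8_t> &Op1,
                                          const std::vector<uint8_t> &Needle,
                                          const std::vector<bool> &Mask) {
  std::vector<bool> Result(Op1.size(), false);
  for (size_t I = 0; I < Op1.size(); ++I)
    if (Mask[I])
      for (uint8_t C : Needle)
        if (Op1[I] == C) {
          Result[I] = true;
          break;
        }
  return Result;
}

This also shows why the sub-128-bit splat path is sound: duplicating an 8-byte needle across the register leaves every segment with the same element set, so membership is unchanged.
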
@@ -6383,6 +6462,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
         DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, CttzOp);
     return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
   }
+  case Intrinsic::experimental_vector_match: {
+    return LowerVectorMatch(Op, DAG);
+  }
   }
 }

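
For context on how this case is reached, a hedged sketch of an emitter for the intrinsic (assumes a recent LLVM where Intrinsic::experimental_vector_match exists and, as the operand handling above suggests, is overloaded on the two vector operand types; emitVectorMatch is a made-up helper):

// Hypothetical front-end/pass snippet emitting the intrinsic that the new
// case dispatches to LowerVectorMatch. Operand order matches the lowering:
// (haystack, needle, mask).
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"

llvm::Value *emitVectorMatch(llvm::IRBuilder<> &B, llvm::Value *Haystack,
                             llvm::Value *Needle, llvm::Value *Mask) {
  return B.CreateIntrinsic(llvm::Intrinsic::experimental_vector_match,
                           {Haystack->getType(), Needle->getType()},
                           {Haystack, Needle, Mask});
}
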
@@ -27153,6 +27235,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
       return;
     }
+    case Intrinsic::experimental_vector_match:
     case Intrinsic::get_active_lane_mask: {
       if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
         return;
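
A closing note: fixed-length <N x i1> results are not legal on AArch64, which is why experimental_vector_match joins get_active_lane_mask on this replacement path; judging from the surrounding code, the operation is re-emitted with a wider integer element type and the mask is recovered with ISD::TRUNCATE. A scalar sketch of that final truncation (hypothetical helper, not LLVM code):

// Illustrative model of the promote-then-truncate result handling: the match
// is computed on widened integer lanes, and TRUNCATE keeps each lane's low
// bit to rebuild the <N x i1> mask.
#include <array>
#include <cstdint>

template <std::size_t N>
std::array<bool, N> truncateLanesToI1(const std::array<uint8_t, N> &Wide) {
  std::array<bool, N> Mask{};
  for (std::size_t I = 0; I < N; ++I)
    Mask[I] = Wide[I] & 1; // TRUNCATE i8 -> i1 keeps the least significant bit.
  return Mask;
}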