diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index f0a703be35207..e87de292639dd 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1501,6 +1501,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Legal); } + if (Subtarget->hasSVE2p1() || + (Subtarget->hasSME2() && Subtarget->isStreaming())) + setOperationAction(ISD::GET_ACTIVE_LANE_MASK, MVT::nxv32i1, Custom); + for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32}) setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom); } @@ -18165,7 +18169,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, /*IsEqual=*/false)) return While; - if (!ST->hasSVE2p1()) + if (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())) return SDValue(); if (!N->hasNUsesOfValue(2, 0)) @@ -27328,6 +27332,37 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults( Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half)); } +void AArch64TargetLowering::ReplaceGetActiveLaneMaskResults( + SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const { + assert((Subtarget->hasSVE2p1() || + (Subtarget->hasSME2() && Subtarget->isStreaming())) && + "Custom lower of get.active.lane.mask missing required feature."); + + assert(N->getValueType(0) == MVT::nxv32i1 && + "Unexpected result type for get.active.lane.mask"); + + SDLoc DL(N); + SDValue Idx = N->getOperand(0); + SDValue TC = N->getOperand(1); + + assert(Idx.getValueType().getFixedSizeInBits() <= 64 && + "Unexpected operand type for get.active.lane.mask"); + + if (Idx.getValueType() != MVT::i64) { + Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx); + TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC); + } + + SDValue ID = + DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64); + EVT HalfVT = 
N->getValueType(0).getHalfNumVectorElementsVT(*DAG.getContext()); + auto WideMask = + DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {HalfVT, HalfVT}, {ID, Idx, TC}); + + Results.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), + {WideMask.getValue(0), WideMask.getValue(1)})); +} + // Create an even/odd pair of X registers holding integer value V. static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) { SDLoc dl(V.getNode()); @@ -27714,6 +27749,9 @@ void AArch64TargetLowering::ReplaceNodeResults( // CONCAT_VECTORS -- but delegate to common code for result type // legalisation return; + case ISD::GET_ACTIVE_LANE_MASK: + ReplaceGetActiveLaneMaskResults(N, Results, DAG); + return; case ISD::INTRINSIC_WO_CHAIN: { EVT VT = N->getValueType(0); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index b59526bf01888..4c6358034af02 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -1318,6 +1318,9 @@ class AArch64TargetLowering : public TargetLowering { void ReplaceExtractSubVectorResults(SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const; + void ReplaceGetActiveLaneMaskResults(SDNode *N, + SmallVectorImpl<SDValue> &Results, + SelectionDAG &DAG) const; bool shouldNormalizeToSelectSequence(LLVMContext &, EVT) const override; diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll index 2d84a69f3144e..c76b50d69b877 100644 --- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll +++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll @@ -1,6 +1,7 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4 ; RUN: llc -mattr=+sve < %s | FileCheck %s -check-prefix CHECK-SVE -; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1 +; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix 
CHECK-SVE2p1-SME2 -check-prefix CHECK-SVE2p1 +; RUN: llc -mattr=+sve -mattr=+sme2 -force-streaming < %s | FileCheck %s -check-prefix CHECK-SVE2p1-SME2 -check-prefix CHECK-SME2 target triple = "aarch64-linux" ; Test combining of getActiveLaneMask with a pair of extract_vector operations. @@ -13,12 +14,12 @@ define void @test_2x8bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 ; CHECK-SVE-NEXT: punpkhi p1.h, p1.b ; CHECK-SVE-NEXT: b use ; -; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count: -; CHECK-SVE2p1: // %bb.0: -; CHECK-SVE2p1-NEXT: mov w8, w1 -; CHECK-SVE2p1-NEXT: mov w9, w0 -; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x9, x8 -; CHECK-SVE2p1-NEXT: b use +; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count: +; CHECK-SVE2p1-SME2: // %bb.0: +; CHECK-SVE2p1-SME2-NEXT: mov w8, w1 +; CHECK-SVE2p1-SME2-NEXT: mov w9, w0 +; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x9, x8 +; CHECK-SVE2p1-SME2-NEXT: b use %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n) %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0) %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8) @@ -34,10 +35,10 @@ define void @test_2x8bit_mask_with_64bit_index_and_trip_count(i64 %i, i64 %n) #0 ; CHECK-SVE-NEXT: punpkhi p1.h, p1.b ; CHECK-SVE-NEXT: b use ; -; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count: -; CHECK-SVE2p1: // %bb.0: -; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x0, x1 -; CHECK-SVE2p1-NEXT: b use +; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count: +; CHECK-SVE2p1-SME2: // %bb.0: +; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1 +; CHECK-SVE2p1-SME2-NEXT: b use %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n) %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0) %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8) @@ -53,12 +54,12 @@ define void @test_edge_case_2x1bit_mask(i64 %i, i64 %n) #0 { ; 
CHECK-SVE-NEXT: punpkhi p1.h, p1.b ; CHECK-SVE-NEXT: b use ; -; CHECK-SVE2p1-LABEL: test_edge_case_2x1bit_mask: -; CHECK-SVE2p1: // %bb.0: -; CHECK-SVE2p1-NEXT: whilelo p1.d, x0, x1 -; CHECK-SVE2p1-NEXT: punpklo p0.h, p1.b -; CHECK-SVE2p1-NEXT: punpkhi p1.h, p1.b -; CHECK-SVE2p1-NEXT: b use +; CHECK-SVE2p1-SME2-LABEL: test_edge_case_2x1bit_mask: +; CHECK-SVE2p1-SME2: // %bb.0: +; CHECK-SVE2p1-SME2-NEXT: whilelo p1.d, x0, x1 +; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b +; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b +; CHECK-SVE2p1-SME2-NEXT: b use %r = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 %i, i64 %n) %v0 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 0) %v1 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 1) @@ -74,10 +75,10 @@ define void @test_boring_case_2x2bit_mask(i64 %i, i64 %n) #0 { ; CHECK-SVE-NEXT: punpkhi p1.h, p1.b ; CHECK-SVE-NEXT: b use ; -; CHECK-SVE2p1-LABEL: test_boring_case_2x2bit_mask: -; CHECK-SVE2p1: // %bb.0: -; CHECK-SVE2p1-NEXT: whilelo { p0.d, p1.d }, x0, x1 -; CHECK-SVE2p1-NEXT: b use +; CHECK-SVE2p1-SME2-LABEL: test_boring_case_2x2bit_mask: +; CHECK-SVE2p1-SME2: // %bb.0: +; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1 +; CHECK-SVE2p1-SME2-NEXT: b use %r = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %i, i64 %n) %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 0) %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 2) @@ -96,14 +97,14 @@ define void @test_partial_extract(i64 %i, i64 %n) #0 { ; CHECK-SVE-NEXT: punpklo p1.h, p2.b ; CHECK-SVE-NEXT: b use ; -; CHECK-SVE2p1-LABEL: test_partial_extract: -; CHECK-SVE2p1: // %bb.0: -; CHECK-SVE2p1-NEXT: whilelo p0.h, x0, x1 -; CHECK-SVE2p1-NEXT: punpklo p1.h, p0.b -; CHECK-SVE2p1-NEXT: punpkhi p2.h, p0.b -; CHECK-SVE2p1-NEXT: punpklo p0.h, p1.b -; CHECK-SVE2p1-NEXT: punpklo p1.h, p2.b -; CHECK-SVE2p1-NEXT: b use +; CHECK-SVE2p1-SME2-LABEL: test_partial_extract: +; CHECK-SVE2p1-SME2: // %bb.0: +; CHECK-SVE2p1-SME2-NEXT: whilelo p0.h, x0, x1 +; 
CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b +; CHECK-SVE2p1-SME2-NEXT: punpkhi p2.h, p0.b +; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b +; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p2.b +; CHECK-SVE2p1-SME2-NEXT: b use %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n) %v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0) %v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4) @@ -111,7 +112,7 @@ define void @test_partial_extract(i64 %i, i64 %n) #0 { ret void } -;; Negative test for when extracting a fixed-length vector. +; Negative test for when extracting a fixed-length vector. define void @test_fixed_extract(i64 %i, i64 %n) #0 { ; CHECK-SVE-LABEL: test_fixed_extract: ; CHECK-SVE: // %bb.0: @@ -144,6 +145,21 @@ define void @test_fixed_extract(i64 %i, i64 %n) #0 { ; CHECK-SVE2p1-NEXT: mov v1.s[1], w11 ; CHECK-SVE2p1-NEXT: // kill: def $d1 killed $d1 killed $q1 ; CHECK-SVE2p1-NEXT: b use +; +; CHECK-SME2-LABEL: test_fixed_extract: +; CHECK-SME2: // %bb.0: +; CHECK-SME2-NEXT: whilelo p0.h, x0, x1 +; CHECK-SME2-NEXT: cset w8, mi +; CHECK-SME2-NEXT: mov z0.h, p0/z, #1 // =0x1 +; CHECK-SME2-NEXT: mov z1.h, z0.h[1] +; CHECK-SME2-NEXT: mov z2.h, z0.h[5] +; CHECK-SME2-NEXT: mov z3.h, z0.h[4] +; CHECK-SME2-NEXT: fmov s0, w8 +; CHECK-SME2-NEXT: zip1 z0.s, z0.s, z1.s +; CHECK-SME2-NEXT: zip1 z1.s, z3.s, z2.s +; CHECK-SME2-NEXT: // kill: def $d0 killed $d0 killed $z0 +; CHECK-SME2-NEXT: // kill: def $d1 killed $d1 killed $z1 +; CHECK-SME2-NEXT: b use %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n) %v0 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0) %v1 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4) @@ -151,6 +167,67 @@ define void @test_fixed_extract(i64 %i, i64 %n) #0 { ret void } +; Illegal Types + +define void @test_2x16bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 { +; CHECK-SVE-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count: +; CHECK-SVE: // %bb.0: +; CHECK-SVE-NEXT: rdvl x8, 
#1 +; CHECK-SVE-NEXT: adds w8, w0, w8 +; CHECK-SVE-NEXT: csinv w8, w8, wzr, lo +; CHECK-SVE-NEXT: whilelo p0.b, w0, w1 +; CHECK-SVE-NEXT: whilelo p1.b, w8, w1 +; CHECK-SVE-NEXT: b use +; +; CHECK-SVE2p1-SME2-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count: +; CHECK-SVE2p1-SME2: // %bb.0: +; CHECK-SVE2p1-SME2-NEXT: mov w8, w1 +; CHECK-SVE2p1-SME2-NEXT: mov w9, w0 +; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.b, p1.b }, x9, x8 +; CHECK-SVE2p1-SME2-NEXT: b use + %r = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32 %i, i32 %n) + %v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 0) + %v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 16) + tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1) + ret void +} + +define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 { +; CHECK-SVE-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count: +; CHECK-SVE: // %bb.0: +; CHECK-SVE-NEXT: rdvl x8, #2 +; CHECK-SVE-NEXT: rdvl x9, #1 +; CHECK-SVE-NEXT: adds w8, w0, w8 +; CHECK-SVE-NEXT: csinv w8, w8, wzr, lo +; CHECK-SVE-NEXT: adds w10, w8, w9 +; CHECK-SVE-NEXT: csinv w10, w10, wzr, lo +; CHECK-SVE-NEXT: whilelo p3.b, w10, w1 +; CHECK-SVE-NEXT: adds w9, w0, w9 +; CHECK-SVE-NEXT: csinv w9, w9, wzr, lo +; CHECK-SVE-NEXT: whilelo p0.b, w0, w1 +; CHECK-SVE-NEXT: whilelo p1.b, w9, w1 +; CHECK-SVE-NEXT: whilelo p2.b, w8, w1 +; CHECK-SVE-NEXT: b use +; +; CHECK-SVE2p1-SME2-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count: +; CHECK-SVE2p1-SME2: // %bb.0: +; CHECK-SVE2p1-SME2-NEXT: rdvl x8, #2 +; CHECK-SVE2p1-SME2-NEXT: mov w9, w1 +; CHECK-SVE2p1-SME2-NEXT: mov w10, w0 +; CHECK-SVE2p1-SME2-NEXT: adds w8, w0, w8 +; CHECK-SVE2p1-SME2-NEXT: csinv w8, w8, wzr, lo +; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.b, p1.b }, x10, x9 +; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.b, p3.b }, x8, x9 +; CHECK-SVE2p1-SME2-NEXT: b use + %r = call <vscale x 64 x i1> @llvm.get.active.lane.mask.nxv64i1.i32(i32 %i, i32 %n) + %v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 0) + %v1 = call <vscale x 16 x i1> 
@llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 16) + %v2 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 32) + %v3 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 48) + tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1, <vscale x 16 x i1> %v2, <vscale x 16 x i1> %v3) + ret void +} + declare void @use(...) attributes #0 = { nounwind }