From 4ffe5e5cb8f33bfc0f5eceb62b75ea1288373ff9 Mon Sep 17 00:00:00 2001 From: Nick Guy Date: Wed, 28 May 2025 16:33:04 +0100 Subject: [PATCH 1/7] [AArch64][SelectionDAG] Add type legalization for partial reduce wide adds --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 12 +- .../Target/AArch64/AArch64ISelLowering.cpp | 35 ++ .../lib/Target/AArch64/AArch64SVEInstrInfo.td | 13 + .../AArch64/sve-partial-reduce-dot-product.ll | 206 +++++++---- .../AArch64/sve-partial-reduce-wide-add.ll | 322 +++++++++++++----- 5 files changed, 426 insertions(+), 162 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 9e418329d15be..af504df596615 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -12676,6 +12676,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { TLI.getTypeToTransformTo(*Context, LHSExtOpVT))) return SDValue(); + EVT ResultVT = N->getValueType(0); + bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND; unsigned NewOpcode = ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA; @@ -12689,7 +12691,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C)) return SDValue(); - return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, + return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp, DAG.getConstant(CTrunc, DL, LHSExtOpVT)); } @@ -12710,8 +12712,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { Op1.getValueType().getVectorElementType() != AccElemVT) return SDValue(); - return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, - RHSExtOp); + return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp, RHSExtOp); } // partial.reduce.umla(acc, zext(op), splat(1)) @@ -12735,7 +12736,10 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { SDValue UnextOp1 = Op1.getOperand(0); EVT UnextOp1VT = UnextOp1.getValueType(); - if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT)) + auto *Context = DAG.getContext(); + if (!TLI.isPartialReduceMLALegalOrCustom( + TLI.getTypeToTransformTo(*Context, N->getValueType(0)), + TLI.getTypeToTransformTo(*Context, UnextOp1VT))) return SDValue(); bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a817ed5f0e917..fecdfe95a082d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1885,6 +1885,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal); setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom); + + // Wide add types + if (Subtarget->hasSVE2() || Subtarget->hasSME()) { + setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Custom); + setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Custom); + setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Custom); + } } // Handle operations that are only available in non-streaming SVE mode. @@ -29230,6 +29237,34 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op, SDValue RHS = Op.getOperand(2); EVT ResultVT = Op.getValueType(); + // Recognise Op as a wide add, if it is then we leave it as-is + // Base: nxv2i64, Subdivision: nxv4i32 + auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool { + assert(Base.isVector() && Subdivision.isVector()); + assert(Base.isScalableVector() == Subdivision.isScalableVector()); + + ElementCount BaseCount = Base.getVectorElementCount(); + ElementCount SubCount = Subdivision.getVectorElementCount(); + if (BaseCount * 2 != SubCount) + return false; + + uint64_t BaseScalarSize = Base.getScalarSizeInBits(); + uint64_t SubScalarSize = Subdivision.getScalarSizeInBits(); + if (BaseScalarSize != SubScalarSize * 2) + return false; + + return true; + }; + if (IsEVTSubdivision(ResultVT, LHS.getValueType())) { + // If it looks like a real wide add, we can leave it as-is and treat it as + // Legal + APInt C; + if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne()) + return Op; + // If it doesn't, then we need to expand it. + return SDValue(); + } + assert((Scalable && ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8) || (!Scalable && ResultVT == MVT::v2i64 && diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index a40ef56f30486..1b1a24394e1f1 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -3813,6 +3813,19 @@ let Predicates = [HasSVE2_or_SME] in { defm USUBWB_ZZZ : sve2_wide_int_arith_wide<0b110, "usubwb", int_aarch64_sve_usubwb>; defm USUBWT_ZZZ : sve2_wide_int_arith_wide<0b111, "usubwt", int_aarch64_sve_usubwt>; + def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))), + (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>; + def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))), + (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>; + def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))), + (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>; + def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))), + (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>; + def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))), + (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>; + def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))), + (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>; + // SVE2 integer multiply long defm SQDMULLB_ZZZ : sve2_wide_int_arith_long<0b11000, "sqdmullb", int_aarch64_sve_sqdmullb>; defm SQDMULLT_ZZZ : sve2_wide_int_arith_long<0b11001, "sqdmullt", int_aarch64_sve_sqdmullt>; diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll index 809a45045b0db..a45b8b710c63a 100644 --- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll @@ -561,31 +561,34 @@ define @udot_no_bin_op_8to64( %acc, %a to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64( %acc, %a.ext) ret %partial.reduce @@ -603,31 +606,34 @@ define @sdot_no_bin_op_8to64( %acc, %a to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64( %acc, %a.ext) ret %partial.reduce @@ -647,18 +653,44 @@ define @not_udot( %acc, % ; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s ; CHECK-NEXT: ret ; -; CHECK-NEWLOWERING-LABEL: not_udot: -; CHECK-NEWLOWERING: // %bb.0: // %entry -; CHECK-NEWLOWERING-NEXT: and z1.h, z1.h, #0xff -; CHECK-NEWLOWERING-NEXT: and z2.h, z2.h, #0xff -; CHECK-NEWLOWERING-NEXT: ptrue p0.s -; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z1.h -; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z2.h -; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h -; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h -; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s -; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s -; CHECK-NEWLOWERING-NEXT: ret +; CHECK-NEWLOWERING-SVE-LABEL: not_udot: +; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry +; CHECK-NEWLOWERING-SVE-NEXT: and z1.h, z1.h, #0xff +; CHECK-NEWLOWERING-SVE-NEXT: and z2.h, z2.h, #0xff +; CHECK-NEWLOWERING-SVE-NEXT: ptrue p0.s +; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z3.s, z1.h +; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z4.s, z2.h +; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.s, z1.h +; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.s, z2.h +; CHECK-NEWLOWERING-SVE-NEXT: mla z0.s, p0/m, z3.s, z4.s +; CHECK-NEWLOWERING-SVE-NEXT: mla z0.s, p0/m, z1.s, z2.s +; CHECK-NEWLOWERING-SVE-NEXT: ret +; +; CHECK-NEWLOWERING-SVE2-LABEL: not_udot: +; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry +; CHECK-NEWLOWERING-SVE2-NEXT: and z1.h, z1.h, #0xff +; CHECK-NEWLOWERING-SVE2-NEXT: and z2.h, z2.h, #0xff +; CHECK-NEWLOWERING-SVE2-NEXT: ptrue p0.s +; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z3.s, z2.h +; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z4.s, z1.h +; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z2.s, z2.h +; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z1.s, z1.h +; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z4.s, z3.s +; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z1.s, z2.s +; CHECK-NEWLOWERING-SVE2-NEXT: ret +; +; CHECK-NEWLOWERING-SME-LABEL: not_udot: +; CHECK-NEWLOWERING-SME: // %bb.0: // %entry +; CHECK-NEWLOWERING-SME-NEXT: and z1.h, z1.h, #0xff +; CHECK-NEWLOWERING-SME-NEXT: and z2.h, z2.h, #0xff +; CHECK-NEWLOWERING-SME-NEXT: ptrue p0.s +; CHECK-NEWLOWERING-SME-NEXT: uunpklo z3.s, z2.h +; CHECK-NEWLOWERING-SME-NEXT: uunpklo z4.s, z1.h +; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z2.s, z2.h +; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z1.s, z1.h +; CHECK-NEWLOWERING-SME-NEXT: mla z0.s, p0/m, z4.s, z3.s +; CHECK-NEWLOWERING-SME-NEXT: mla z0.s, p0/m, z1.s, z2.s +; CHECK-NEWLOWERING-SME-NEXT: ret entry: %a.wide = zext %a to %b.wide = zext %b to @@ -681,18 +713,44 @@ define @not_udot_wide( %acc, %a to %b.wide = zext %b to diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll index 5148d3da6c737..8f9f26a5d5b23 100644 --- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll +++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll @@ -1,7 +1,8 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 -; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE2 -; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE -; RUN: llc -mtriple=aarch64 -mattr=+sve2 -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING +; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK-SVE2 +; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK-SVE +; RUN: llc -mtriple=aarch64 -mattr=+sve -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING-SVE +; RUN: llc -mtriple=aarch64 -mattr=+sve2 -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING-SVE2 define @signed_wide_add_nxv4i32( %acc, %input){ ; CHECK-SVE2-LABEL: signed_wide_add_nxv4i32: @@ -18,13 +19,19 @@ define @signed_wide_add_nxv4i32( %acc, %input to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64( %acc, %input.wide) @@ -46,13 +53,19 @@ define @unsigned_wide_add_nxv4i32( %acc, %input to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64( %acc, %input.wide) @@ -74,13 +87,19 @@ define @signed_wide_add_nxv8i16( %acc, %input to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32( %acc, %input.wide) @@ -102,13 +121,19 @@ define @unsigned_wide_add_nxv8i16( %acc, %input to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32( %acc, %input.wide) @@ -130,13 +155,19 @@ define @signed_wide_add_nxv16i8( %acc, %input to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16( %acc, %input.wide) @@ -158,13 +189,19 @@ define @unsigned_wide_add_nxv16i8( %acc, %input to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16( %acc, %input.wide) @@ -172,15 +209,43 @@ entry: } define @signed_wide_add_nxv4i16( %acc, %input){ -; CHECK-LABEL: signed_wide_add_nxv4i16: -; CHECK: // %bb.0: // %entry -; CHECK-NEXT: ptrue p0.s -; CHECK-NEXT: sxth z1.s, p0/m, z1.s -; CHECK-NEXT: uunpklo z2.d, z1.s -; CHECK-NEXT: uunpkhi z1.d, z1.s -; CHECK-NEXT: add z0.d, z0.d, z2.d -; CHECK-NEXT: add z0.d, z1.d, z0.d -; CHECK-NEXT: ret +; CHECK-SVE2-LABEL: signed_wide_add_nxv4i16: +; CHECK-SVE2: // %bb.0: // %entry +; CHECK-SVE2-NEXT: ptrue p0.s +; CHECK-SVE2-NEXT: sxth z1.s, p0/m, z1.s +; CHECK-SVE2-NEXT: uunpklo z2.d, z1.s +; CHECK-SVE2-NEXT: uunpkhi z1.d, z1.s +; CHECK-SVE2-NEXT: add z0.d, z0.d, z2.d +; CHECK-SVE2-NEXT: add z0.d, z1.d, z0.d +; CHECK-SVE2-NEXT: ret +; +; CHECK-SVE-LABEL: signed_wide_add_nxv4i16: +; CHECK-SVE: // %bb.0: // %entry +; CHECK-SVE-NEXT: ptrue p0.s +; CHECK-SVE-NEXT: sxth z1.s, p0/m, z1.s +; CHECK-SVE-NEXT: uunpklo z2.d, z1.s +; CHECK-SVE-NEXT: uunpkhi z1.d, z1.s +; CHECK-SVE-NEXT: add z0.d, z0.d, z2.d +; CHECK-SVE-NEXT: add z0.d, z1.d, z0.d +; CHECK-SVE-NEXT: ret +; +; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv4i16: +; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry +; CHECK-NEWLOWERING-SVE-NEXT: ptrue p0.s +; CHECK-NEWLOWERING-SVE-NEXT: sxth z1.s, p0/m, z1.s +; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z2.d, z1.s +; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.d, z1.s +; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d +; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z1.d, z0.d +; CHECK-NEWLOWERING-SVE-NEXT: ret +; +; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv4i16: +; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry +; CHECK-NEWLOWERING-SVE2-NEXT: ptrue p0.s +; CHECK-NEWLOWERING-SVE2-NEXT: sxth z1.s, p0/m, z1.s +; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z1.s +; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z1.s +; CHECK-NEWLOWERING-SVE2-NEXT: ret entry: %input.wide = sext %input to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32( %acc, %input.wide) @@ -188,14 +253,39 @@ entry: } define @unsigned_wide_add_nxv4i16( %acc, %input){ -; CHECK-LABEL: unsigned_wide_add_nxv4i16: -; CHECK: // %bb.0: // %entry -; CHECK-NEXT: and z1.s, z1.s, #0xffff -; CHECK-NEXT: uunpklo z2.d, z1.s -; CHECK-NEXT: uunpkhi z1.d, z1.s -; CHECK-NEXT: add z0.d, z0.d, z2.d -; CHECK-NEXT: add z0.d, z1.d, z0.d -; CHECK-NEXT: ret +; CHECK-SVE2-LABEL: unsigned_wide_add_nxv4i16: +; CHECK-SVE2: // %bb.0: // %entry +; CHECK-SVE2-NEXT: and z1.s, z1.s, #0xffff +; CHECK-SVE2-NEXT: uunpklo z2.d, z1.s +; CHECK-SVE2-NEXT: uunpkhi z1.d, z1.s +; CHECK-SVE2-NEXT: add z0.d, z0.d, z2.d +; CHECK-SVE2-NEXT: add z0.d, z1.d, z0.d +; CHECK-SVE2-NEXT: ret +; +; CHECK-SVE-LABEL: unsigned_wide_add_nxv4i16: +; CHECK-SVE: // %bb.0: // %entry +; CHECK-SVE-NEXT: and z1.s, z1.s, #0xffff +; CHECK-SVE-NEXT: uunpklo z2.d, z1.s +; CHECK-SVE-NEXT: uunpkhi z1.d, z1.s +; CHECK-SVE-NEXT: add z0.d, z0.d, z2.d +; CHECK-SVE-NEXT: add z0.d, z1.d, z0.d +; CHECK-SVE-NEXT: ret +; +; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv4i16: +; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry +; CHECK-NEWLOWERING-SVE-NEXT: and z1.s, z1.s, #0xffff +; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z2.d, z1.s +; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.d, z1.s +; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d +; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z1.d, z0.d +; CHECK-NEWLOWERING-SVE-NEXT: ret +; +; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv4i16: +; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry +; CHECK-NEWLOWERING-SVE2-NEXT: and z1.s, z1.s, #0xffff +; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z1.s +; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z1.s +; CHECK-NEWLOWERING-SVE2-NEXT: ret entry: %input.wide = zext %input to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32( %acc, %input.wide) @@ -203,17 +293,49 @@ entry: } define @signed_wide_add_nxv8i32( %acc, %input){ -; CHECK-LABEL: signed_wide_add_nxv8i32: -; CHECK: // %bb.0: // %entry -; CHECK-NEXT: sunpklo z4.d, z3.s -; CHECK-NEXT: sunpklo z5.d, z2.s -; CHECK-NEXT: sunpkhi z3.d, z3.s -; CHECK-NEXT: sunpkhi z2.d, z2.s -; CHECK-NEXT: add z0.d, z0.d, z5.d -; CHECK-NEXT: add z1.d, z1.d, z4.d -; CHECK-NEXT: add z0.d, z0.d, z2.d -; CHECK-NEXT: add z1.d, z1.d, z3.d -; CHECK-NEXT: ret +; CHECK-SVE2-LABEL: signed_wide_add_nxv8i32: +; CHECK-SVE2: // %bb.0: // %entry +; CHECK-SVE2-NEXT: sunpklo z4.d, z3.s +; CHECK-SVE2-NEXT: sunpklo z5.d, z2.s +; CHECK-SVE2-NEXT: sunpkhi z3.d, z3.s +; CHECK-SVE2-NEXT: sunpkhi z2.d, z2.s +; CHECK-SVE2-NEXT: add z0.d, z0.d, z5.d +; CHECK-SVE2-NEXT: add z1.d, z1.d, z4.d +; CHECK-SVE2-NEXT: add z0.d, z0.d, z2.d +; CHECK-SVE2-NEXT: add z1.d, z1.d, z3.d +; CHECK-SVE2-NEXT: ret +; +; CHECK-SVE-LABEL: signed_wide_add_nxv8i32: +; CHECK-SVE: // %bb.0: // %entry +; CHECK-SVE-NEXT: sunpklo z4.d, z3.s +; CHECK-SVE-NEXT: sunpklo z5.d, z2.s +; CHECK-SVE-NEXT: sunpkhi z3.d, z3.s +; CHECK-SVE-NEXT: sunpkhi z2.d, z2.s +; CHECK-SVE-NEXT: add z0.d, z0.d, z5.d +; CHECK-SVE-NEXT: add z1.d, z1.d, z4.d +; CHECK-SVE-NEXT: add z0.d, z0.d, z2.d +; CHECK-SVE-NEXT: add z1.d, z1.d, z3.d +; CHECK-SVE-NEXT: ret +; +; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv8i32: +; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry +; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z4.d, z3.s +; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z5.d, z2.s +; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z3.d, z3.s +; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z2.d, z2.s +; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z5.d +; CHECK-NEWLOWERING-SVE-NEXT: add z1.d, z1.d, z4.d +; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d +; CHECK-NEWLOWERING-SVE-NEXT: add z1.d, z1.d, z3.d +; CHECK-NEWLOWERING-SVE-NEXT: ret +; +; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv8i32: +; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry +; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z1.d, z1.d, z3.s +; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z2.s +; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z1.d, z1.d, z3.s +; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z2.s +; CHECK-NEWLOWERING-SVE2-NEXT: ret entry: %input.wide = sext %input to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64( %acc, %input.wide) @@ -221,17 +343,49 @@ entry: } define @unsigned_wide_add_nxv8i32( %acc, %input){ -; CHECK-LABEL: unsigned_wide_add_nxv8i32: -; CHECK: // %bb.0: // %entry -; CHECK-NEXT: uunpklo z4.d, z3.s -; CHECK-NEXT: uunpklo z5.d, z2.s -; CHECK-NEXT: uunpkhi z3.d, z3.s -; CHECK-NEXT: uunpkhi z2.d, z2.s -; CHECK-NEXT: add z0.d, z0.d, z5.d -; CHECK-NEXT: add z1.d, z1.d, z4.d -; CHECK-NEXT: add z0.d, z0.d, z2.d -; CHECK-NEXT: add z1.d, z1.d, z3.d -; CHECK-NEXT: ret +; CHECK-SVE2-LABEL: unsigned_wide_add_nxv8i32: +; CHECK-SVE2: // %bb.0: // %entry +; CHECK-SVE2-NEXT: uunpklo z4.d, z3.s +; CHECK-SVE2-NEXT: uunpklo z5.d, z2.s +; CHECK-SVE2-NEXT: uunpkhi z3.d, z3.s +; CHECK-SVE2-NEXT: uunpkhi z2.d, z2.s +; CHECK-SVE2-NEXT: add z0.d, z0.d, z5.d +; CHECK-SVE2-NEXT: add z1.d, z1.d, z4.d +; CHECK-SVE2-NEXT: add z0.d, z0.d, z2.d +; CHECK-SVE2-NEXT: add z1.d, z1.d, z3.d +; CHECK-SVE2-NEXT: ret +; +; CHECK-SVE-LABEL: unsigned_wide_add_nxv8i32: +; CHECK-SVE: // %bb.0: // %entry +; CHECK-SVE-NEXT: uunpklo z4.d, z3.s +; CHECK-SVE-NEXT: uunpklo z5.d, z2.s +; CHECK-SVE-NEXT: uunpkhi z3.d, z3.s +; CHECK-SVE-NEXT: uunpkhi z2.d, z2.s +; CHECK-SVE-NEXT: add z0.d, z0.d, z5.d +; CHECK-SVE-NEXT: add z1.d, z1.d, z4.d +; CHECK-SVE-NEXT: add z0.d, z0.d, z2.d +; CHECK-SVE-NEXT: add z1.d, z1.d, z3.d +; CHECK-SVE-NEXT: ret +; +; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv8i32: +; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry +; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z4.d, z3.s +; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z5.d, z2.s +; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z3.d, z3.s +; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.d, z2.s +; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z5.d +; CHECK-NEWLOWERING-SVE-NEXT: add z1.d, z1.d, z4.d +; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d +; CHECK-NEWLOWERING-SVE-NEXT: add z1.d, z1.d, z3.d +; CHECK-NEWLOWERING-SVE-NEXT: ret +; +; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv8i32: +; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry +; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z1.d, z1.d, z3.s +; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z2.s +; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z1.d, z1.d, z3.s +; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z2.s +; CHECK-NEWLOWERING-SVE2-NEXT: ret entry: %input.wide = zext %input to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64( %acc, %input.wide) From 6a7b359b3346c0751fb6728d3d2aa31ecaed46dd Mon Sep 17 00:00:00 2001 From: Nick Guy Date: Tue, 27 May 2025 17:29:00 +0100 Subject: [PATCH 2/7] Replace custom lowering with tablegen patterns. --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 7 +-- .../Target/AArch64/AArch64ISelLowering.cpp | 63 ++++++++++--------- .../lib/Target/AArch64/AArch64SVEInstrInfo.td | 13 ++++ .../AArch64/sve-partial-reduce-dot-product.ll | 44 ++++--------- 4 files changed, 60 insertions(+), 67 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index af504df596615..b0ce39010c97c 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -12676,8 +12676,6 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { TLI.getTypeToTransformTo(*Context, LHSExtOpVT))) return SDValue(); - EVT ResultVT = N->getValueType(0); - bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND; unsigned NewOpcode = ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA; @@ -12691,7 +12689,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C)) return SDValue(); - return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp, + return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, DAG.getConstant(CTrunc, DL, LHSExtOpVT)); } @@ -12712,7 +12710,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { Op1.getValueType().getVectorElementType() != AccElemVT) return SDValue(); - return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp, RHSExtOp); + return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, + RHSExtOp); } // partial.reduce.umla(acc, zext(op), splat(1)) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index fecdfe95a082d..0120eba2c894c 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1888,9 +1888,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, // Wide add types if (Subtarget->hasSVE2() || Subtarget->hasSME()) { - setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Custom); - setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Custom); - setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Custom); + setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Legal); + setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Legal); + setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Legal); } } @@ -29236,34 +29236,35 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op, SDValue LHS = Op.getOperand(1); SDValue RHS = Op.getOperand(2); EVT ResultVT = Op.getValueType(); - - // Recognise Op as a wide add, if it is then we leave it as-is - // Base: nxv2i64, Subdivision: nxv4i32 - auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool { - assert(Base.isVector() && Subdivision.isVector()); - assert(Base.isScalableVector() == Subdivision.isScalableVector()); - - ElementCount BaseCount = Base.getVectorElementCount(); - ElementCount SubCount = Subdivision.getVectorElementCount(); - if (BaseCount * 2 != SubCount) - return false; - - uint64_t BaseScalarSize = Base.getScalarSizeInBits(); - uint64_t SubScalarSize = Subdivision.getScalarSizeInBits(); - if (BaseScalarSize != SubScalarSize * 2) - return false; - - return true; - }; - if (IsEVTSubdivision(ResultVT, LHS.getValueType())) { - // If it looks like a real wide add, we can leave it as-is and treat it as - // Legal - APInt C; - if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne()) - return Op; - // If it doesn't, then we need to expand it. - return SDValue(); - } + // + // // Recognise Op as a wide add, if it is then we leave it as-is + // // Base: nxv2i64, Subdivision: nxv4i32 + // auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool { + // assert(Base.isVector() && Subdivision.isVector()); + // assert(Base.isScalableVector() == Subdivision.isScalableVector()); + // + // ElementCount BaseCount = Base.getVectorElementCount(); + // ElementCount SubCount = Subdivision.getVectorElementCount(); + // if (BaseCount * 2 != SubCount) + // return false; + // + // uint64_t BaseScalarSize = Base.getScalarSizeInBits(); + // uint64_t SubScalarSize = Subdivision.getScalarSizeInBits(); + // if (BaseScalarSize != SubScalarSize * 2) + // return false; + // + // return true; + // }; + // if (IsEVTSubdivision(ResultVT, LHS.getValueType())) { + // // If it looks like a real wide add, we can leave it as-is and treat it + // as + // // Legal + // APInt C; + // if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne()) + // return Op; + // // If it doesn't, then we need to expand it. + // return SDValue(); + // } assert((Scalable && ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8) || diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 1b1a24394e1f1..487650f9ad9c0 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -3826,6 +3826,19 @@ let Predicates = [HasSVE2_or_SME] in { def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))), (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>; + def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), + (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), + (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)), + (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)), + (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), + (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), + (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + // SVE2 integer multiply long defm SQDMULLB_ZZZ : sve2_wide_int_arith_long<0b11000, "sqdmullb", int_aarch64_sve_sqdmullb>; defm SQDMULLT_ZZZ : sve2_wide_int_arith_long<0b11001, "sqdmullt", int_aarch64_sve_sqdmullt>; diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll index a45b8b710c63a..203606c8ffacc 100644 --- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll @@ -668,28 +668,18 @@ define @not_udot( %acc, % ; ; CHECK-NEWLOWERING-SVE2-LABEL: not_udot: ; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry -; CHECK-NEWLOWERING-SVE2-NEXT: and z1.h, z1.h, #0xff ; CHECK-NEWLOWERING-SVE2-NEXT: and z2.h, z2.h, #0xff -; CHECK-NEWLOWERING-SVE2-NEXT: ptrue p0.s -; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z3.s, z2.h -; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z4.s, z1.h -; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z2.s, z2.h -; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z1.s, z1.h -; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z4.s, z3.s -; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z1.s, z2.s +; CHECK-NEWLOWERING-SVE2-NEXT: and z1.h, z1.h, #0xff +; CHECK-NEWLOWERING-SVE2-NEXT: umlalb z0.h, z1.b, z2.b +; CHECK-NEWLOWERING-SVE2-NEXT: umlalt z0.h, z1.b, z2.b ; CHECK-NEWLOWERING-SVE2-NEXT: ret ; ; CHECK-NEWLOWERING-SME-LABEL: not_udot: ; CHECK-NEWLOWERING-SME: // %bb.0: // %entry -; CHECK-NEWLOWERING-SME-NEXT: and z1.h, z1.h, #0xff ; CHECK-NEWLOWERING-SME-NEXT: and z2.h, z2.h, #0xff -; CHECK-NEWLOWERING-SME-NEXT: ptrue p0.s -; CHECK-NEWLOWERING-SME-NEXT: uunpklo z3.s, z2.h -; CHECK-NEWLOWERING-SME-NEXT: uunpklo z4.s, z1.h -; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z2.s, z2.h -; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z1.s, z1.h -; CHECK-NEWLOWERING-SME-NEXT: mla z0.s, p0/m, z4.s, z3.s -; CHECK-NEWLOWERING-SME-NEXT: mla z0.s, p0/m, z1.s, z2.s +; CHECK-NEWLOWERING-SME-NEXT: and z1.h, z1.h, #0xff +; CHECK-NEWLOWERING-SME-NEXT: umlalb z0.h, z1.b, z2.b +; CHECK-NEWLOWERING-SME-NEXT: umlalt z0.h, z1.b, z2.b ; CHECK-NEWLOWERING-SME-NEXT: ret entry: %a.wide = zext %a to @@ -728,28 +718,18 @@ define @not_udot_wide( %acc, %a to From ba78d71047d5c6103ac16b112d0560ab797bff3a Mon Sep 17 00:00:00 2001 From: Nick Guy Date: Tue, 27 May 2025 18:08:52 +0100 Subject: [PATCH 3/7] Remove dead code --- .../Target/AArch64/AArch64ISelLowering.cpp | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 0120eba2c894c..2f47f2610b78c 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -29236,35 +29236,6 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op, SDValue LHS = Op.getOperand(1); SDValue RHS = Op.getOperand(2); EVT ResultVT = Op.getValueType(); - // - // // Recognise Op as a wide add, if it is then we leave it as-is - // // Base: nxv2i64, Subdivision: nxv4i32 - // auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool { - // assert(Base.isVector() && Subdivision.isVector()); - // assert(Base.isScalableVector() == Subdivision.isScalableVector()); - // - // ElementCount BaseCount = Base.getVectorElementCount(); - // ElementCount SubCount = Subdivision.getVectorElementCount(); - // if (BaseCount * 2 != SubCount) - // return false; - // - // uint64_t BaseScalarSize = Base.getScalarSizeInBits(); - // uint64_t SubScalarSize = Subdivision.getScalarSizeInBits(); - // if (BaseScalarSize != SubScalarSize * 2) - // return false; - // - // return true; - // }; - // if (IsEVTSubdivision(ResultVT, LHS.getValueType())) { - // // If it looks like a real wide add, we can leave it as-is and treat it - // as - // // Legal - // APInt C; - // if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne()) - // return Op; - // // If it doesn't, then we need to expand it. - // return SDValue(); - // } assert((Scalable && ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8) || From 9aca4589b98d1f9c84dd162c693279ef395971e9 Mon Sep 17 00:00:00 2001 From: Nick Guy Date: Wed, 28 May 2025 16:34:08 +0100 Subject: [PATCH 4/7] Use correct instructions for types --- .../lib/Target/AArch64/AArch64SVEInstrInfo.td | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 487650f9ad9c0..51cadc3b73c31 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -3818,25 +3818,25 @@ let Predicates = [HasSVE2_or_SME] in { def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))), (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>; def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))), - (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>; + (UADDWT_ZZZ_S (UADDWB_ZZZ_S $Acc, $Input), $Input)>; def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))), - (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>; + (SADDWT_ZZZ_S (SADDWB_ZZZ_S $Acc, $Input), $Input)>; def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))), - (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>; + (UADDWT_ZZZ_H (UADDWB_ZZZ_H $Acc, $Input), $Input)>; def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))), - (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>; + (SADDWT_ZZZ_H (SADDWB_ZZZ_H $Acc, $Input), $Input)>; - def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), - (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; - def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), - (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), + (UMLALT_ZZZ_D (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), + (SMLALT_ZZZ_D (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)), - (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + (UMLALT_ZZZ_S (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)), - (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; - def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), + (SMLALT_ZZZ_S (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; - def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), + def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; // SVE2 integer multiply long From 4c514eac0355e14e4ffa75885e43b1ff89f58162 Mon Sep 17 00:00:00 2001 From: Nick Guy Date: Wed, 28 May 2025 15:52:30 +0100 Subject: [PATCH 5/7] Update tests --- .../AArch64/sve-partial-reduce-dot-product.ll | 8 ++++---- .../AArch64/sve-partial-reduce-wide-add.ll | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll index 203606c8ffacc..55c879deb6217 100644 --- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll @@ -671,7 +671,7 @@ define @not_udot( %acc, % ; CHECK-NEWLOWERING-SVE2-NEXT: and z2.h, z2.h, #0xff ; CHECK-NEWLOWERING-SVE2-NEXT: and z1.h, z1.h, #0xff ; CHECK-NEWLOWERING-SVE2-NEXT: umlalb z0.h, z1.b, z2.b -; CHECK-NEWLOWERING-SVE2-NEXT: umlalt z0.h, z1.b, z2.b +; CHECK-NEWLOWERING-SVE2-NEXT: umlalt z0.s, z1.h, z2.h ; CHECK-NEWLOWERING-SVE2-NEXT: ret ; ; CHECK-NEWLOWERING-SME-LABEL: not_udot: @@ -679,7 +679,7 @@ define @not_udot( %acc, % ; CHECK-NEWLOWERING-SME-NEXT: and z2.h, z2.h, #0xff ; CHECK-NEWLOWERING-SME-NEXT: and z1.h, z1.h, #0xff ; CHECK-NEWLOWERING-SME-NEXT: umlalb z0.h, z1.b, z2.b -; CHECK-NEWLOWERING-SME-NEXT: umlalt z0.h, z1.b, z2.b +; CHECK-NEWLOWERING-SME-NEXT: umlalt z0.s, z1.h, z2.h ; CHECK-NEWLOWERING-SME-NEXT: ret entry: %a.wide = zext %a to @@ -721,7 +721,7 @@ define @not_udot_wide( %acc, @not_udot_wide( %acc, %a to diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll index 8f9f26a5d5b23..428dd4c3a0154 100644 --- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll +++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll @@ -97,8 +97,8 @@ define @signed_wide_add_nxv8i16( %acc, %input to @@ -131,8 +131,8 @@ define @unsigned_wide_add_nxv8i16( %acc, %input to @@ -165,8 +165,8 @@ define @signed_wide_add_nxv16i8( %acc, %input to @@ -199,8 +199,8 @@ define @unsigned_wide_add_nxv16i8( %acc, %input to From 26c1098be6578d514d4bfa8c9b5b1e31c3dbb33f Mon Sep 17 00:00:00 2001 From: Nick Guy Date: Wed, 28 May 2025 16:55:59 +0100 Subject: [PATCH 6/7] Update tests after rebase --- .../neon-partial-reduce-dot-product.ll | 30 ++++--------------- .../AArch64/sve-partial-reduce-dot-product.ll | 12 ++++---- 2 files changed, 12 insertions(+), 30 deletions(-) diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll index 2b68c963ad319..d977d8fc9cf21 100644 --- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll @@ -917,20 +917,11 @@ define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){ ; ; CHECK-NEWLOWERING-I8MM-LABEL: udot_no_bin_op_8to64: ; CHECK-NEWLOWERING-I8MM: // %bb.0: -; CHECK-NEWLOWERING-I8MM-NEXT: ushll v3.8h, v2.8b, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.4s, v3.4h, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: ushll v5.4s, v2.4h, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v3.4s, v3.8h, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.4s, v2.8h, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: uaddw v1.2d, v1.2d, v5.2s +; CHECK-NEWLOWERING-I8MM-NEXT: movi v3.16b, #1 +; CHECK-NEWLOWERING-I8MM-NEXT: movi v4.2d, #0000000000000000 +; CHECK-NEWLOWERING-I8MM-NEXT: udot v4.4s, v2.16b, v3.16b ; CHECK-NEWLOWERING-I8MM-NEXT: uaddw v0.2d, v0.2d, v4.2s -; CHECK-NEWLOWERING-I8MM-NEXT: uaddw2 v1.2d, v1.2d, v5.4s ; CHECK-NEWLOWERING-I8MM-NEXT: uaddw2 v0.2d, v0.2d, v4.4s -; CHECK-NEWLOWERING-I8MM-NEXT: uaddw v1.2d, v1.2d, v2.2s -; CHECK-NEWLOWERING-I8MM-NEXT: uaddw v0.2d, v0.2d, v3.2s -; CHECK-NEWLOWERING-I8MM-NEXT: uaddw2 v1.2d, v1.2d, v2.4s -; CHECK-NEWLOWERING-I8MM-NEXT: uaddw2 v0.2d, v0.2d, v3.4s ; CHECK-NEWLOWERING-I8MM-NEXT: ret %a.wide = zext <16 x i8> %a to <16 x i64> %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide) @@ -967,20 +958,11 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){ ; ; CHECK-NEWLOWERING-I8MM-LABEL: sdot_no_bin_op_8to64: ; CHECK-NEWLOWERING-I8MM: // %bb.0: -; CHECK-NEWLOWERING-I8MM-NEXT: sshll v3.8h, v2.8b, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.4s, v3.4h, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: sshll v5.4s, v2.4h, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.4s, v3.8h, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.4s, v2.8h, #0 -; CHECK-NEWLOWERING-I8MM-NEXT: saddw v1.2d, v1.2d, v5.2s +; CHECK-NEWLOWERING-I8MM-NEXT: movi v3.16b, #1 +; CHECK-NEWLOWERING-I8MM-NEXT: movi v4.2d, #0000000000000000 +; CHECK-NEWLOWERING-I8MM-NEXT: sdot v4.4s, v2.16b, v3.16b ; CHECK-NEWLOWERING-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s -; CHECK-NEWLOWERING-I8MM-NEXT: saddw2 v1.2d, v1.2d, v5.4s ; CHECK-NEWLOWERING-I8MM-NEXT: saddw2 v0.2d, v0.2d, v4.4s -; CHECK-NEWLOWERING-I8MM-NEXT: saddw v1.2d, v1.2d, v2.2s -; CHECK-NEWLOWERING-I8MM-NEXT: saddw v0.2d, v0.2d, v3.2s -; CHECK-NEWLOWERING-I8MM-NEXT: saddw2 v1.2d, v1.2d, v2.4s -; CHECK-NEWLOWERING-I8MM-NEXT: saddw2 v0.2d, v0.2d, v3.4s ; CHECK-NEWLOWERING-I8MM-NEXT: ret %a.wide = sext <16 x i8> %a to <16 x i64> %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide) diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll index 55c879deb6217..006083d843370 100644 --- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll @@ -566,10 +566,10 @@ define @udot_no_bin_op_8to64( %acc, @sdot_no_bin_op_8to64( %acc, Date: Thu, 29 May 2025 12:49:00 +0100 Subject: [PATCH 7/7] Use correctly-typed instructions for the lower half too --- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td | 8 ++++---- .../CodeGen/AArch64/sve-partial-reduce-dot-product.ll | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 51cadc3b73c31..91db6b6fc7984 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -3827,13 +3827,13 @@ let Predicates = [HasSVE2_or_SME] in { (SADDWT_ZZZ_H (SADDWB_ZZZ_H $Acc, $Input), $Input)>; def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), - (UMLALT_ZZZ_D (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + (UMLALT_ZZZ_D (UMLALB_ZZZ_D $Acc, $LHS, $RHS), $LHS, $RHS)>; def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)), - (SMLALT_ZZZ_D (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + (SMLALT_ZZZ_D (SMLALB_ZZZ_D $Acc, $LHS, $RHS), $LHS, $RHS)>; def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)), - (UMLALT_ZZZ_S (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + (UMLALT_ZZZ_S (UMLALB_ZZZ_S $Acc, $LHS, $RHS), $LHS, $RHS)>; def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)), - (SMLALT_ZZZ_S (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; + (SMLALT_ZZZ_S (SMLALB_ZZZ_S $Acc, $LHS, $RHS), $LHS, $RHS)>; def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>; def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)), diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll index 006083d843370..d3ccfaaf20a22 100644 --- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll @@ -670,7 +670,7 @@ define @not_udot( %acc, % ; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry ; CHECK-NEWLOWERING-SVE2-NEXT: and z2.h, z2.h, #0xff ; CHECK-NEWLOWERING-SVE2-NEXT: and z1.h, z1.h, #0xff -; CHECK-NEWLOWERING-SVE2-NEXT: umlalb z0.h, z1.b, z2.b +; CHECK-NEWLOWERING-SVE2-NEXT: umlalb z0.s, z1.h, z2.h ; CHECK-NEWLOWERING-SVE2-NEXT: umlalt z0.s, z1.h, z2.h ; CHECK-NEWLOWERING-SVE2-NEXT: ret ; @@ -678,7 +678,7 @@ define @not_udot( %acc, % ; CHECK-NEWLOWERING-SME: // %bb.0: // %entry ; CHECK-NEWLOWERING-SME-NEXT: and z2.h, z2.h, #0xff ; CHECK-NEWLOWERING-SME-NEXT: and z1.h, z1.h, #0xff -; CHECK-NEWLOWERING-SME-NEXT: umlalb z0.h, z1.b, z2.b +; CHECK-NEWLOWERING-SME-NEXT: umlalb z0.s, z1.h, z2.h ; CHECK-NEWLOWERING-SME-NEXT: umlalt z0.s, z1.h, z2.h ; CHECK-NEWLOWERING-SME-NEXT: ret entry: @@ -720,7 +720,7 @@ define @not_udot_wide( %acc, @not_udot_wide( %acc,