diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 0126b97c9fb9a..1bc1208e41b56 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1867,6 +1867,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, // Other pairs will default to 'Expand'. setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal); setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal); + + setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom); } // Handle operations that are only available in non-streaming SVE mode. @@ -7767,6 +7769,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerFLDEXP(Op, DAG); case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return LowerVECTOR_HISTOGRAM(Op, DAG); + case ISD::PARTIAL_REDUCE_SMLA: + case ISD::PARTIAL_REDUCE_UMLA: + return LowerPARTIAL_REDUCE_MLA(Op, DAG); } } @@ -29509,6 +29514,40 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op, return Scatter; } +/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing +/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can +/// however still make use of the dot product instruction by instead +/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64. +SDValue +AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + + SDValue Acc = Op.getOperand(0); + SDValue LHS = Op.getOperand(1); + SDValue RHS = Op.getOperand(2); + EVT ResultVT = Op.getValueType(); + assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8); + + SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32, + DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS); + + bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA; + if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) { + unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB; + unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT; + SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode); + return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode); + } + + unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO; + unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI; + auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode); + auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode); + auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi); + return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended); +} + SDValue AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op, SelectionDAG &DAG) const { diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index d9b535b910b80..9d8d1c22258be 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -1181,6 +1181,7 @@ class AArch64TargetLowering : public TargetLowering { SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const; SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const; SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const; SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll index 039cac01008b8..709c519a8b7c4 100644 --- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll @@ -1,7 +1,9 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 ; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM -; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING +; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE +; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2 +; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME define @udot( %acc, %a, %b) { ; CHECK-LABEL: udot: @@ -196,46 +198,31 @@ define @udot_8to64( %acc, %a to %b.wide = zext %b to @@ -256,46 +243,31 @@ define @sdot_8to64( %acc, %a to %b.wide = sext %b to