Skip to content

Commit 1d0994b

Browse files
committed
Address comments and add SVE2/SME guard & fallback
1 parent 00371df commit 1d0994b

File tree

2 files changed: +60 -26 lines changed

2 files changed: +60 -26 lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -29514,26 +29514,39 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2951429514
return Scatter;
2951529515
}
2951629516

29517+
/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator type that is too
29518+
/// wide to be used for (u|s)dot, we can still make use of the dot product
29519+
/// instruction by instead treating the accumulator as a vector type with twice
29520+
/// as many elements that are each half as wide, accumulating the low and high
29521+
/// parts of the result together in the actual accumulator afterwards.
2951729522
SDValue
2951829523
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2951929524
SelectionDAG &DAG) const {
2952029525
SDLoc DL(Op);
2952129526

29522-
auto Acc = Op.getOperand(0);
29523-
auto LHS = Op.getOperand(1);
29524-
auto RHS = Op.getOperand(2);
29525-
auto ResultVT = Op.getValueType();
29527+
SDValue Acc = Op.getOperand(0);
29528+
SDValue LHS = Op.getOperand(1);
29529+
SDValue RHS = Op.getOperand(2);
29530+
EVT ResultVT = Op.getValueType();
2952629531
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
2952729532

29528-
EVT InputVT = MVT::nxv4i32;
29529-
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, InputVT,
29530-
DAG.getConstant(0, DL, InputVT), LHS, RHS);
29533+
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
29534+
DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
2953129535

2953229536
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29533-
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
29534-
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
29535-
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29536-
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29537+
if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
29538+
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
29539+
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
29540+
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29541+
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29542+
}
29543+
29544+
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
29545+
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
29546+
auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
29547+
auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
29548+
auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
29549+
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
2953729550
}
2953829551

2953929552
SDValue

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 36 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,8 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
22
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
33
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
4-
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING
4+
; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
5+
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
56

67
define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
78
; CHECK-LABEL: udot:
@@ -196,13 +197,23 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
196197
; CHECK-NEXT: add z1.d, z1.d, z3.d
197198
; CHECK-NEXT: ret
198199
;
199-
; CHECK-NEWLOWERING-LABEL: udot_8to64:
200-
; CHECK-NEWLOWERING: // %bb.0: // %entry
201-
; CHECK-NEWLOWERING-NEXT: movi v4.2d, #0000000000000000
202-
; CHECK-NEWLOWERING-NEXT: udot z4.s, z2.b, z3.b
203-
; CHECK-NEWLOWERING-NEXT: uaddwb z0.d, z0.d, z4.s
204-
; CHECK-NEWLOWERING-NEXT: uaddwt z0.d, z0.d, z4.s
205-
; CHECK-NEWLOWERING-NEXT: ret
200+
; CHECK-NEWLOWERING-SVE-LABEL: udot_8to64:
201+
; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
202+
; CHECK-NEWLOWERING-SVE-NEXT: movi v4.2d, #0000000000000000
203+
; CHECK-NEWLOWERING-SVE-NEXT: udot z4.s, z2.b, z3.b
204+
; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.d, z4.s
205+
; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z3.d, z4.s
206+
; CHECK-NEWLOWERING-SVE-NEXT: add z2.d, z3.d, z2.d
207+
; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
208+
; CHECK-NEWLOWERING-SVE-NEXT: ret
209+
;
210+
; CHECK-NEWLOWERING-SVE2-LABEL: udot_8to64:
211+
; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
212+
; CHECK-NEWLOWERING-SVE2-NEXT: movi v4.2d, #0000000000000000
213+
; CHECK-NEWLOWERING-SVE2-NEXT: udot z4.s, z2.b, z3.b
214+
; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z4.s
215+
; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z4.s
216+
; CHECK-NEWLOWERING-SVE2-NEXT: ret
206217
entry:
207218
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
208219
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -223,13 +234,23 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
223234
; CHECK-NEXT: add z1.d, z1.d, z3.d
224235
; CHECK-NEXT: ret
225236
;
226-
; CHECK-NEWLOWERING-LABEL: sdot_8to64:
227-
; CHECK-NEWLOWERING: // %bb.0: // %entry
228-
; CHECK-NEWLOWERING-NEXT: movi v4.2d, #0000000000000000
229-
; CHECK-NEWLOWERING-NEXT: sdot z4.s, z2.b, z3.b
230-
; CHECK-NEWLOWERING-NEXT: saddwb z0.d, z0.d, z4.s
231-
; CHECK-NEWLOWERING-NEXT: saddwt z0.d, z0.d, z4.s
232-
; CHECK-NEWLOWERING-NEXT: ret
237+
; CHECK-NEWLOWERING-SVE-LABEL: sdot_8to64:
238+
; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
239+
; CHECK-NEWLOWERING-SVE-NEXT: movi v4.2d, #0000000000000000
240+
; CHECK-NEWLOWERING-SVE-NEXT: sdot z4.s, z2.b, z3.b
241+
; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z2.d, z4.s
242+
; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z3.d, z4.s
243+
; CHECK-NEWLOWERING-SVE-NEXT: add z2.d, z3.d, z2.d
244+
; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
245+
; CHECK-NEWLOWERING-SVE-NEXT: ret
246+
;
247+
; CHECK-NEWLOWERING-SVE2-LABEL: sdot_8to64:
248+
; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
249+
; CHECK-NEWLOWERING-SVE2-NEXT: movi v4.2d, #0000000000000000
250+
; CHECK-NEWLOWERING-SVE2-NEXT: sdot z4.s, z2.b, z3.b
251+
; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z4.s
252+
; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z4.s
253+
; CHECK-NEWLOWERING-SVE2-NEXT: ret
233254
entry:
234255
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
235256
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>

0 commit comments

Comments
 (0)