Commit daa1cdb

[AArch64][SVE] Add codegen support for partial reduction lowering to wide add instructions
1 parent 7d1e98c commit daa1cdb

File tree

2 files changed: +134, -1 lines changed
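For context: llvm.experimental.vector.partial.reduce.add folds every element of a wide input vector into a narrower accumulator vector. Only the overall sum is guaranteed; the lane-to-lane mapping is left to the target, which is exactly the freedom this commit exploits. A minimal scalar sketch of those semantics (illustrative only; the modulo lane mapping is one valid choice, not something the intrinsic mandates):

  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // Reference model: every input element is added into some accumulator
  // lane. Any mapping that consumes each input element exactly once is a
  // valid implementation; modulo indexing is used here for concreteness.
  std::vector<int64_t> partialReduceAdd(std::vector<int64_t> Acc,
                                        const std::vector<int64_t> &Input) {
    for (size_t I = 0; I < Input.size(); ++I)
      Acc[I % Acc.size()] += Input[I];
    return Acc;
  }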

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 60 additions & 1 deletion
@@ -2042,7 +2042,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
 
   EVT VT = EVT::getEVT(I->getType());
   return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
-         VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
+         VT != MVT::nxv8i16 && VT != MVT::v4i64 && VT != MVT::v4i32 &&
+         VT != MVT::v2i32 && VT != MVT::v8i16;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
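This hunk adds nxv8i16 and v8i16 to the set of types for which the hook declines expansion, so partial reductions producing those types now survive as intrinsic nodes and reach the combine added further down. A simplified sketch of how such a hook is consumed on the SelectionDAGBuilder side (paraphrased from memory, not quoted from this commit; treat the exact shape as an assumption):

  // When the target declines expansion, the call is lowered to a plain
  // INTRINSIC_WO_CHAIN node that target combines can pattern-match later;
  // otherwise the generic lane-wise expansion of the reduction runs.
  if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
    visitTargetIntrinsic(I, Intrinsic);
    return;
  }
  // ... generic expansion path ...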
@@ -21783,6 +21784,62 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
+SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+                                          const AArch64Subtarget *Subtarget,
+                                          SelectionDAG &DAG) {
+
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+         getIntrinsicID(N) ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Expected a partial reduction node");
+
+  bool Scalable = N->getValueType(0).isScalableVector();
+  if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  auto Accumulator = N->getOperand(1);
+  auto ExtInput = N->getOperand(2);
+
+  EVT AccumulatorType = Accumulator.getValueType();
+  EVT AccumulatorElementType = AccumulatorType.getVectorElementType();
+
+  if (ExtInput.getValueType().getVectorElementType() != AccumulatorElementType)
+    return SDValue();
+
+  unsigned ExtInputOpcode = ExtInput->getOpcode();
+  if (!ISD::isExtOpcode(ExtInputOpcode))
+    return SDValue();
+
+  auto Input = ExtInput->getOperand(0);
+  EVT InputType = Input.getValueType();
+
+  // To do this transformation, output element size needs to be double input
+  // element size, and output number of elements needs to be half the input
+  // number of elements
+  if (!(InputType.getVectorElementType().getSizeInBits() * 2 ==
+        AccumulatorElementType.getSizeInBits()) ||
+      !(AccumulatorType.getVectorElementCount() * 2 ==
+        InputType.getVectorElementCount()) ||
+      !(AccumulatorType.isScalableVector() == InputType.isScalableVector()))
+    return SDValue();
+
+  bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
+  auto BottomIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwb
+                                       : Intrinsic::aarch64_sve_uaddwb;
+  auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
+                                    : Intrinsic::aarch64_sve_uaddwt;
+
+  auto BottomID =
+      DAG.getTargetConstant(BottomIntrinsic, DL, AccumulatorElementType);
+  auto BottomNode = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType,
+                                BottomID, Accumulator, Input);
+  auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccumulatorElementType);
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType, TopID,
+                     BottomNode, Input);
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21794,6 +21851,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::experimental_vector_partial_reduce_add: {
     if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
       return Dot;
+    if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
+      return WideAdd;
     return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
                                    N->getOperand(1), N->getOperand(2));
   }
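Why a bottom/top pair suffices: the SVE2 wide adds each widen and add one half of the input, with "bottom" (saddwb/uaddwb) taking the even-numbered input elements and "top" (saddwt/uaddwt) the odd-numbered ones, so chaining the two folds every input lane into the half-length accumulator exactly once. A self-contained scalar sketch of that decomposition for the signed i32-to-i64 case (an illustrative model, not the production lowering):

  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // Acc models an nxv2i64 accumulator, Input an nxv4i32 vector: Acc has
  // half as many lanes, each twice as wide.
  std::vector<int64_t> wideAddPair(std::vector<int64_t> Acc,
                                   const std::vector<int32_t> &Input) {
    for (size_t I = 0; I < Acc.size(); ++I) {
      Acc[I] += static_cast<int64_t>(Input[2 * I]);     // saddwb: even lanes
      Acc[I] += static_cast<int64_t>(Input[2 * I + 1]); // saddwt: odd lanes
    }
    return Acc;
  }

Each input lane is consumed exactly once, so the result satisfies the partial.reduce.add contract regardless of which lane mapping the intrinsic's consumer assumed.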
Lines changed: 74 additions & 0 deletions (new test file)
@@ -0,0 +1,74 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: signed_wide_add_nxv4i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.d, z0.d, z1.s
+; CHECK-NEXT:    saddwt z0.d, z0.d, z1.s
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.d, z0.d, z1.s
+; CHECK-NEXT:    uaddwt z0.d, z0.d, z1.s
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: signed_wide_add_nxv8i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.s, z0.s, z1.h
+; CHECK-NEXT:    saddwt z0.s, z0.s, z1.h
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.s, z0.s, z1.h
+; CHECK-NEXT:    uaddwt z0.s, z0.s, z1.h
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: signed_wide_add_nxv16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.h, z0.h, z1.b
+; CHECK-NEXT:    saddwt z0.h, z0.h, z1.b
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+  %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+  ret <vscale x 8 x i16> %partial.reduce
+}
+
+define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.h, z0.h, z1.b
+; CHECK-NEXT:    uaddwt z0.h, z0.h, z1.b
+; CHECK-NEXT:    ret
+entry:
+  %input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+  %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+  ret <vscale x 8 x i16> %partial.reduce
+}
