Skip to content

Commit d74ba76

Browse files
committed
[CodeGen] Implement widening for partial.reduce.add
Widening of accumulator/result is done by padding the accumulator with zero elements, performing the partial reduction and then partially reducing the wide vector result (using extract lo/hi + add) into the narrow part of the result vector. Widening of the input vector is done by padding it with zero elements.
1 parent 224a717 commit d74ba76

File tree

3 files changed

+78
-0
lines changed

3 files changed

+78
-0
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
11171117
SDValue WidenVecRes_Unary(SDNode *N);
11181118
SDValue WidenVecRes_InregOp(SDNode *N);
11191119
SDValue WidenVecRes_UnaryOpWithTwoResults(SDNode *N, unsigned ResNo);
1120+
SDValue WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
11201121
void ReplaceOtherWidenResults(SDNode *N, SDNode *WidenNode,
11211122
unsigned WidenResNo);
11221123

@@ -1152,6 +1153,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
11521153
SDValue WidenVecOp_VP_REDUCE(SDNode *N);
11531154
SDValue WidenVecOp_ExpOp(SDNode *N);
11541155
SDValue WidenVecOp_VP_CttzElements(SDNode *N);
1156+
SDValue WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N);
11551157

11561158
/// Helper function to generate a set of operations to perform
11571159
/// a vector operation for a wider type.

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5136,6 +5136,10 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
51365136
if (!unrollExpandedOp())
51375137
Res = WidenVecRes_UnaryOpWithTwoResults(N, ResNo);
51385138
break;
5139+
case ISD::PARTIAL_REDUCE_UMLA:
5140+
case ISD::PARTIAL_REDUCE_SMLA:
5141+
Res = WidenVecRes_PARTIAL_REDUCE_MLA(N);
5142+
break;
51395143
}
51405144
}
51415145

@@ -6995,6 +6999,34 @@ SDValue DAGTypeLegalizer::WidenVecRes_STRICT_FSETCC(SDNode *N) {
69956999
return DAG.getBuildVector(WidenVT, dl, Scalars);
69967000
}
69977001

7002+
// Widening the result of a partial reductions is implemented by
7003+
// accumulating into a wider (zero-padded) vector, then incrementally
7004+
// reducing that (extract half vector and add) until it fits
7005+
// the original type.
7006+
SDValue DAGTypeLegalizer::WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
7007+
SDLoc DL(N);
7008+
EVT VT = N->getValueType(0);
7009+
EVT WideAccVT = TLI.getTypeToTransformTo(*DAG.getContext(),
7010+
N->getOperand(0).getValueType());
7011+
SDValue Zero = DAG.getConstant(0, DL, WideAccVT);
7012+
SDValue MulOp1 = N->getOperand(1);
7013+
SDValue MulOp2 = N->getOperand(2);
7014+
SDValue Acc = DAG.getInsertSubvector(DL, Zero, N->getOperand(0), 0);
7015+
SDValue WidenedRes =
7016+
DAG.getNode(N->getOpcode(), DL, WideAccVT, Acc, MulOp1, MulOp2);
7017+
while (ElementCount::isKnownLT(
7018+
VT.getVectorElementCount(),
7019+
WidenedRes.getValueType().getVectorElementCount())) {
7020+
EVT HalfVT =
7021+
WidenedRes.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
7022+
SDValue Lo = DAG.getExtractSubvector(DL, HalfVT, WidenedRes, 0);
7023+
SDValue Hi = DAG.getExtractSubvector(DL, HalfVT, WidenedRes,
7024+
HalfVT.getVectorMinNumElements());
7025+
WidenedRes = DAG.getNode(ISD::ADD, DL, HalfVT, Lo, Hi);
7026+
}
7027+
return DAG.getInsertSubvector(DL, Zero, WidenedRes, 0);
7028+
}
7029+
69987030
//===----------------------------------------------------------------------===//
69997031
// Widen Vector Operand
70007032
//===----------------------------------------------------------------------===//
@@ -7127,6 +7159,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
71277159
case ISD::VP_REDUCE_FMINIMUM:
71287160
Res = WidenVecOp_VP_REDUCE(N);
71297161
break;
7162+
case ISD::PARTIAL_REDUCE_UMLA:
7163+
case ISD::PARTIAL_REDUCE_SMLA:
7164+
Res = WidenVecOp_PARTIAL_REDUCE_MLA(N);
7165+
break;
71307166
case ISD::VP_CTTZ_ELTS:
71317167
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
71327168
Res = WidenVecOp_VP_CttzElements(N);
@@ -8026,6 +8062,21 @@ SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
80268062
{Source, Mask, N->getOperand(2)}, N->getFlags());
80278063
}
80288064

8065+
SDValue DAGTypeLegalizer::WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
8066+
// Widening of multiplicant operands only. The result and accumulator
8067+
// should already be legal types.
8068+
SDLoc DL(N);
8069+
EVT WideOpVT = TLI.getTypeToTransformTo(*DAG.getContext(),
8070+
N->getOperand(1).getValueType());
8071+
SDValue Acc = N->getOperand(0);
8072+
SDValue WidenedOp1 = DAG.getInsertSubvector(
8073+
DL, DAG.getConstant(0, DL, WideOpVT), N->getOperand(1), 0);
8074+
SDValue WidenedOp2 = DAG.getInsertSubvector(
8075+
DL, DAG.getConstant(0, DL, WideOpVT), N->getOperand(2), 0);
8076+
return DAG.getNode(N->getOpcode(), DL, Acc.getValueType(), Acc, WidenedOp1,
8077+
WidenedOp2);
8078+
}
8079+
80298080
//===----------------------------------------------------------------------===//
80308081
// Vector Widening Utilities
80318082
//===----------------------------------------------------------------------===//
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
; RUN: llc -mattr=+sve,+dotprod < %s | FileCheck %s
2+
3+
define void @partial_reduce_widen_v1i32_acc_v16i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
4+
%acc = load <1 x i32>, ptr %accptr
5+
%vec = load <16 x i32>, ptr %vecptr
6+
%partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <16 x i32> %vec)
7+
store <1 x i32> %partial.reduce, ptr %resptr
8+
ret void
9+
}
10+
11+
define void @partial_reduce_widen_v3i32_acc_v12i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
12+
%acc = load <3 x i32>, ptr %accptr
13+
%vec = load <12 x i32>, ptr %vecptr
14+
%partial.reduce = call <3 x i32> @llvm.vector.partial.reduce.add(<3 x i32> %acc, <12 x i32> %vec)
15+
store <3 x i32> %partial.reduce, ptr %resptr
16+
ret void
17+
}
18+
19+
define void @partial_reduce_widen_v4i32_acc_v20i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
20+
%acc = load <1 x i32>, ptr %accptr
21+
%vec = load <20 x i32>, ptr %vecptr
22+
%partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <20 x i32> %vec)
23+
store <1 x i32> %partial.reduce, ptr %resptr
24+
ret void
25+
}

0 commit comments

Comments
 (0)