[CodeGen] Implement widening for partial.reduce.add #161834
Widening of accumulator/result is done by padding the accumulator with zero elements, performing the partial reduction and then partially reducing the wide vector result (using extract lo/hi + add) into the narrow part of the result vector. Widening of the input vector is done by padding it with zero elements.
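The accumulator/result widening described above can be modeled on plain arrays. This is a minimal sketch, not the patch itself: `Vec`, `partialReduceAdd` and `widenedResultPartialReduceAdd` are hypothetical names, the lane grouping (input element `I` goes to lane `I % width`) is just one valid choice for a partial reduction, and the wide lane count is assumed to be a power of two.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Vec = std::vector<uint32_t>;

// Scalar model of a partial add reduction: lane I % Acc.size() of the
// accumulator collects input element I. (Assumed grouping, see lead-in.)
Vec partialReduceAdd(Vec Acc, const Vec &In) {
  assert(!Acc.empty() && In.size() % Acc.size() == 0 &&
         "input length must be a multiple of the accumulator length");
  for (std::size_t I = 0; I < In.size(); ++I)
    Acc[I % Acc.size()] += In[I];
  return Acc;
}

// Model of widening the accumulator/result from Acc.size() lanes to
// WideLanes lanes (WideLanes is assumed to be a power of two).
Vec widenedResultPartialReduceAdd(const Vec &Acc, const Vec &In,
                                  std::size_t WideLanes) {
  assert(WideLanes >= Acc.size());
  // 1. Pad the accumulator with zero lanes.
  Vec Wide(WideLanes, 0);
  for (std::size_t I = 0; I < Acc.size(); ++I)
    Wide[I] = Acc[I];
  // 2. Perform the partial reduction on the wide accumulator.
  Vec Res = partialReduceAdd(Wide, In);
  // 3. Fold the high half onto the low half while the result still has
  //    more lanes than the original (narrow) type.
  while (Acc.size() < Res.size()) {
    std::size_t Half = Res.size() / 2;
    for (std::size_t I = 0; I < Half; ++I)
      Res[I] += Res[I + Half];
    Res.resize(Half);
  }
  // 4. Insert the folded value into a zero vector of the wide type.
  Res.resize(WideLanes, 0);
  return Res;
}
```

With a one-lane accumulator holding 10, an input of 1..16 and a wide type of two lanes, the model yields 146 in lane 0 and zero in the padding lane. For a three-lane accumulator widened to four lanes, the fold stops at two lanes, so the total is preserved but spread over lanes 0 and 1.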
@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-llvm-selectiondag
Author: Sander de Smalen (sdesmalen-arm)
Changes
Widening of accumulator/result is done by padding the accumulator with zero elements, performing the partial reduction and then partially reducing the wide vector result (using extract lo/hi + add) into the narrow part of the result vector. Widening of the input vector is done by padding it with zero elements.
Full diff: https://github.com/llvm/llvm-project/pull/161834.diff
3 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 586c3411791f9..c4d69aa48434a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -1117,6 +1117,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecRes_Unary(SDNode *N);
SDValue WidenVecRes_InregOp(SDNode *N);
SDValue WidenVecRes_UnaryOpWithTwoResults(SDNode *N, unsigned ResNo);
+ SDValue WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
void ReplaceOtherWidenResults(SDNode *N, SDNode *WidenNode,
unsigned WidenResNo);
@@ -1152,6 +1153,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecOp_VP_REDUCE(SDNode *N);
SDValue WidenVecOp_ExpOp(SDNode *N);
SDValue WidenVecOp_VP_CttzElements(SDNode *N);
+ SDValue WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N);
/// Helper function to generate a set of operations to perform
/// a vector operation for a wider type.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 87d5453cd98cf..4b409eb5f4c6c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5136,6 +5136,10 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
if (!unrollExpandedOp())
Res = WidenVecRes_UnaryOpWithTwoResults(N, ResNo);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Res = WidenVecRes_PARTIAL_REDUCE_MLA(N);
+ break;
}
}
@@ -6995,6 +6999,34 @@ SDValue DAGTypeLegalizer::WidenVecRes_STRICT_FSETCC(SDNode *N) {
return DAG.getBuildVector(WidenVT, dl, Scalars);
}
+// Widening the result of a partial reduction is implemented by
+// accumulating into a wider (zero-padded) vector, then incrementally
+// reducing that (extract half vector and add) until it fits
+// the original type.
+SDValue DAGTypeLegalizer::WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ EVT WideAccVT = TLI.getTypeToTransformTo(*DAG.getContext(),
+ N->getOperand(0).getValueType());
+ SDValue Zero = DAG.getConstant(0, DL, WideAccVT);
+ SDValue MulOp1 = N->getOperand(1);
+ SDValue MulOp2 = N->getOperand(2);
+ SDValue Acc = DAG.getInsertSubvector(DL, Zero, N->getOperand(0), 0);
+ SDValue WidenedRes =
+ DAG.getNode(N->getOpcode(), DL, WideAccVT, Acc, MulOp1, MulOp2);
+ while (ElementCount::isKnownLT(
+ VT.getVectorElementCount(),
+ WidenedRes.getValueType().getVectorElementCount())) {
+ EVT HalfVT =
+ WidenedRes.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
+ SDValue Lo = DAG.getExtractSubvector(DL, HalfVT, WidenedRes, 0);
+ SDValue Hi = DAG.getExtractSubvector(DL, HalfVT, WidenedRes,
+ HalfVT.getVectorMinNumElements());
+ WidenedRes = DAG.getNode(ISD::ADD, DL, HalfVT, Lo, Hi);
+ }
+ return DAG.getInsertSubvector(DL, Zero, WidenedRes, 0);
+}
+
//===----------------------------------------------------------------------===//
// Widen Vector Operand
//===----------------------------------------------------------------------===//
@@ -7127,6 +7159,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VP_REDUCE_FMINIMUM:
Res = WidenVecOp_VP_REDUCE(N);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Res = WidenVecOp_PARTIAL_REDUCE_MLA(N);
+ break;
case ISD::VP_CTTZ_ELTS:
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
Res = WidenVecOp_VP_CttzElements(N);
@@ -8026,6 +8062,21 @@ SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
{Source, Mask, N->getOperand(2)}, N->getFlags());
}
+SDValue DAGTypeLegalizer::WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+ // Widening of multiplicand operands only. The result and accumulator
+ // should already be legal types.
+ SDLoc DL(N);
+ EVT WideOpVT = TLI.getTypeToTransformTo(*DAG.getContext(),
+ N->getOperand(1).getValueType());
+ SDValue Acc = N->getOperand(0);
+ SDValue WidenedOp1 = DAG.getInsertSubvector(
+ DL, DAG.getConstant(0, DL, WideOpVT), N->getOperand(1), 0);
+ SDValue WidenedOp2 = DAG.getInsertSubvector(
+ DL, DAG.getConstant(0, DL, WideOpVT), N->getOperand(2), 0);
+ return DAG.getNode(N->getOpcode(), DL, Acc.getValueType(), Acc, WidenedOp1,
+ WidenedOp2);
+}
+
//===----------------------------------------------------------------------===//
// Vector Widening Utilities
//===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll b/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
new file mode 100644
index 0000000000000..a6b215b610fca
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
@@ -0,0 +1,25 @@
+; RUN: llc -mattr=+sve,+dotprod < %s | FileCheck %s
+
+define void @partial_reduce_widen_v1i32_acc_v16i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+ %acc = load <1 x i32>, ptr %accptr
+ %vec = load <16 x i32>, ptr %vecptr
+ %partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <16 x i32> %vec)
+ store <1 x i32> %partial.reduce, ptr %resptr
+ ret void
+}
+
+define void @partial_reduce_widen_v3i32_acc_v12i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+ %acc = load <3 x i32>, ptr %accptr
+ %vec = load <12 x i32>, ptr %vecptr
+ %partial.reduce = call <3 x i32> @llvm.vector.partial.reduce.add(<3 x i32> %acc, <12 x i32> %vec)
+ store <3 x i32> %partial.reduce, ptr %resptr
+ ret void
+}
+
+define void @partial_reduce_widen_v4i32_acc_v20i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+ %acc = load <1 x i32>, ptr %accptr
+ %vec = load <20 x i32>, ptr %vecptr
+ %partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <20 x i32> %vec)
+ store <1 x i32> %partial.reduce, ptr %resptr
+ ret void
+}
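The operand-widening path in the diff above (`WidenVecOp_PARTIAL_REDUCE_MLA`) pads both multiplicands with zeros and leaves the accumulator alone. A minimal scalar sketch of why that is harmless follows. The names `partialReduceMLA`, `padWithZeros` and `widenedOperandPartialReduceMLA` are hypothetical, the lane grouping is an assumed one, and widening a 20-lane input to 32 lanes is an illustrative assumption rather than something the patch states.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Vec = std::vector<uint32_t>;

// Scalar model of a partial multiply-accumulate reduction: lane
// I % Acc.size() accumulates the product A[I] * B[I].
Vec partialReduceMLA(Vec Acc, const Vec &A, const Vec &B) {
  assert(!Acc.empty() && A.size() == B.size() &&
         A.size() % Acc.size() == 0 &&
         "operand length must be a multiple of the accumulator length");
  for (std::size_t I = 0; I < A.size(); ++I)
    Acc[I % Acc.size()] += A[I] * B[I];
  return Acc;
}

// Insert V at lane 0 of an all-zero vector with WideLanes lanes.
Vec padWithZeros(Vec V, std::size_t WideLanes) {
  assert(WideLanes >= V.size());
  V.resize(WideLanes, 0);
  return V;
}

// Model of operand widening: only the multiplicands are padded. Every
// padding lane contributes 0 * 0 == 0, so the accumulator lanes are
// unchanged compared with the unwidened operation.
Vec widenedOperandPartialReduceMLA(const Vec &Acc, const Vec &A,
                                   const Vec &B, std::size_t WideLanes) {
  return partialReduceMLA(Acc, padWithZeros(A, WideLanes),
                          padWithZeros(B, WideLanes));
}
```

Setting `B` to all ones models a plain add reduction. In this model the padded and unpadded calls return identical vectors as long as the accumulator width divides both the original and the widened operand length.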
%acc = load <1 x i32>, ptr %accptr
%vec = load <16 x i32>, ptr %vecptr
Could these tests just take the vectors as parameters and return the vector instead?
The reason I didn't do that was so that I wouldn't have to pass/return illegal types to the function (the ABI only describes how legal types are passed)
That makes sense 👍
…vector have a known scalar factor.
LGTM, thanks!
if (getTypeAction(MulOpVT) == TargetLowering::TypeWidenVector) {
  EVT WideMulVT = GetWidenedVector(MulOp1).getValueType();
  assert(WideMulVT.getVectorElementCount().isKnownMultipleOf(WideAccEC) &&
         "Widening to a vector with less elements than accumulator?");
  SDValue Zero = DAG.getConstant(0, DL, WideMulVT);
  MulOp1 = DAG.getInsertSubvector(DL, Zero, MulOp1, 0);
  MulOp2 = DAG.getInsertSubvector(DL, Zero, MulOp2, 0);
} else if (!MulOpEC.isKnownMultipleOf(WideAccEC)) {
  assert(getTypeAction(MulOpVT) != TargetLowering::TypeLegal &&
         "Expected Mul operands to need legalisation");
  EVT WideMulVT = EVT::getVectorVT(*DAG.getContext(),
                                   MulOpVT.getVectorElementType(), WideAccEC);
  SDValue Zero = DAG.getConstant(0, DL, WideMulVT);
  MulOp1 = DAG.getInsertSubvector(DL, Zero, MulOp1, 0);
  MulOp2 = DAG.getInsertSubvector(DL, Zero, MulOp2, 0);
}
It looks like the only differences between these two blocks are their assertions and how they assign `WideMulVT`, so they could share the zero vector construction and subvector insertion as:
bool NeedsWidening = getTypeAction(MulOpVT) == TargetLowering::TypeWidenVector;
bool NarrowMultipleOfWide = MulOpEC.isKnownMultipleOf(WideAccEC);
if (NeedsWidening || !NarrowMultipleOfWide) {
  EVT WideMulVT;
  if (NeedsWidening) {
    assert(...)
    ...
  } else {
    assert(...)
    ...
  }
  SDValue Zero = ...
  MulOp1 = ...
  MulOp2 = ...
}
I've only partially (pun intended) reviewed the patch because several of the new asserts look more like hope than expectation, so I'm wondering whether it would be better to first implement widening that works in all cases and then add more optimal handling for the special cases you care about. For general widening I think something like the following can work:
This will move all the complication relating to having compatible operand types to the special cases where you can turn the asserts into requirements. An alternative to the above is to reduce the accumulator instead:
This seems simpler and keeps the implicit extension for the mul operands. The chances are the first version will then be necessary for WidenVecOp_PARTIAL_REDUCE_MLA, but I think that will always be the case so it just depends on which looks cleaner once all the bases are covered.
That's not the case; these asserts are asserting an expectation. When widening the result vector, we must still ensure that the input vector (to be reduced) has an element count that is a known multiple of that of the (to be widened) accumulator. If both the input and the accumulator need widening, then both will be widened to powers of two, and the accumulator can't be widened to something wider than the input vector. The second assert verifies that when the accumulator/result is widened and the input does not need widening but is not a multiple of the result, the input must require some other kind of legalisation. I don't believe that either of these two cases can ever happen, hence the asserts.
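The power-of-two argument above can be checked by brute force on small sizes. This is only a sketch under a stated assumption: widening is modeled as rounding the lane count up to the next power of two, which is what the comment describes but not necessarily what every target does. `nextPowerOfTwo` and `widenedMultipleHolds` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>

// Assumed widening rule: round the lane count up to a power of two.
std::size_t nextPowerOfTwo(std::size_t N) {
  std::size_t P = 1;
  while (P < N)
    P *= 2;
  return P;
}

// For every accumulator width A up to MaxAccLanes and every input width
// A * K with K up to MaxFactor, check that the widened input lane count
// is still a multiple of the widened accumulator lane count.
bool widenedMultipleHolds(std::size_t MaxAccLanes, std::size_t MaxFactor) {
  for (std::size_t A = 1; A <= MaxAccLanes; ++A) {
    for (std::size_t K = 1; K <= MaxFactor; ++K) {
      std::size_t WideAcc = nextPowerOfTwo(A);
      std::size_t WideIn = nextPowerOfTwo(A * K);
      // The input has at least as many lanes as the accumulator, so its
      // rounded-up size is at least as large; two powers of two always
      // divide one another in that direction.
      if (WideIn % WideAcc != 0)
        return false;
    }
  }
  return true;
}
```

Under this rule the check passes for every pair tried, for example a 3-lane accumulator (widened to 4) with a 12-lane input (widened to 16). The model says nothing about targets whose legal vector lengths are not powers of two.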
This PR is not particularly trying to optimally handle any cases.
You can't add/accumulate into a poison vector: if we did that, the valid part of the result of the partial reduction would be accumulated into those poison lanes and the result would be rubbish.
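For the pad-then-fold scheme used in this patch, the effect of non-zero padding can be illustrated with a small scalar model. This is a sketch only: poison is stood in for by an arbitrary garbage value (real poison semantics are stronger than that), `foldWithPadding` is a hypothetical helper, and the lane grouping is an assumed one.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Vec = std::vector<uint32_t>;

// Pad Acc to WideLanes lanes with PadValue, accumulate In lane-wise
// (element I goes to lane I % WideLanes), then fold hi onto lo until the
// result is back at Acc.size() lanes or fewer.
Vec foldWithPadding(const Vec &Acc, const Vec &In, std::size_t WideLanes,
                    uint32_t PadValue) {
  assert(!Acc.empty() && WideLanes >= Acc.size() &&
         In.size() % WideLanes == 0);
  Vec Res(WideLanes, PadValue);
  for (std::size_t I = 0; I < Acc.size(); ++I)
    Res[I] = Acc[I];
  for (std::size_t I = 0; I < In.size(); ++I)
    Res[I % WideLanes] += In[I];
  while (Acc.size() < Res.size()) {
    std::size_t Half = Res.size() / 2;
    // Whatever sat in the padding lanes is added into the valid lanes.
    for (std::size_t I = 0; I < Half; ++I)
      Res[I] += Res[I + Half];
    Res.resize(Half);
  }
  return Res;
}
```

With zero padding, a one-lane accumulator of 10 and an input of 1..16 fold to 146. With a garbage pad value the same call returns 146 plus that garbage, because the fold step moves the padding lane's contents into lane 0.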
The first snippet reduces the mul to a scalar, which is then inserted into zero before a lane-wise add with the widened accumulator, so the extra/widened lanes being poison is fine? The second snippet reduces the accumulator to a single-element vector that is used to rewrite the partial reduction, so again no poison? There was a typo (now fixed) in the final step, where I should have inserted the partial reduction result into zero instead of poison. Although I guess the problem with the second snippet is that it might trigger recursive widening, making it a non-starter.
Separating out the assert conversation. What I'm worried about is the assumption that legal vector types will have a power-of-two vector length. There's at least one target where this is not the case and thus it's possible to see something like