-
Notifications
You must be signed in to change notification settings - Fork 15k
[WebAssembly] Support partial-reduce accumulator #158060
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
We currently only support partial.reduce.add when it is performing a multiply-accumulate. This change adds support for any partial reduction whose input is extended, where we can take advantage of extadd_pairwise.
|
@llvm/pr-subscribers-backend-webassembly Author: Sam Parker (sparker-arm) Changes: We currently only support partial.reduce.add when it is performing a multiply-accumulate. This change adds support for any partial reduction whose input is extended, where we can take advantage of extadd_pairwise. Full diff: https://github.com/llvm/llvm-project/pull/158060.diff 4 Files Affected:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
index 1eae3586d16b8..23108e429eda8 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
@@ -28,6 +28,7 @@ HANDLE_NODETYPE(BR_IF)
HANDLE_NODETYPE(BR_TABLE)
HANDLE_NODETYPE(DOT)
HANDLE_NODETYPE(EXT_ADD_PAIRWISE_U)
+HANDLE_NODETYPE(EXT_ADD_PAIRWISE_S)
HANDLE_NODETYPE(SHUFFLE)
HANDLE_NODETYPE(SWIZZLE)
HANDLE_NODETYPE(VEC_SHL)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index fe100dab427ef..aea27ba32d37e 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -422,24 +422,30 @@ bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
return true;
EVT VT = EVT::getEVT(I->getType());
+ if (VT.getSizeInBits() > 128)
+ return true;
+
auto Op1 = I->getOperand(1);
if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
- if (InstructionOpcodeToISD(InputInst->getOpcode()) != ISD::MUL)
- return true;
-
- if (isa<Instruction>(InputInst->getOperand(0)) &&
- isa<Instruction>(InputInst->getOperand(1))) {
- // dot only supports signed inputs but also support lowering unsigned.
- if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
- cast<Instruction>(InputInst->getOperand(1))->getOpcode())
- return true;
-
- EVT Op1VT = EVT::getEVT(Op1->getType());
- if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
- ((VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()) ||
- (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
- return false;
+ unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
+ if (Opcode == ISD::MUL) {
+ if (isa<Instruction>(InputInst->getOperand(0)) &&
+ isa<Instruction>(InputInst->getOperand(1))) {
+ // dot only supports signed inputs but also support lowering unsigned.
+ if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
+ cast<Instruction>(InputInst->getOperand(1))->getOpcode())
+ return true;
+
+ EVT Op1VT = EVT::getEVT(Op1->getType());
+ if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
+ ((VT.getVectorElementCount() * 2 ==
+ Op1VT.getVectorElementCount()) ||
+ (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
+ return false;
+ }
+ } else if (ISD::isExtOpcode(Opcode)) {
+ return false;
}
}
return true;
@@ -2117,77 +2123,93 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
SDLoc DL(N);
- SDValue Mul = N->getOperand(2);
- assert(Mul->getOpcode() == ISD::MUL && "expected mul input");
-
- SDValue ExtendLHS = Mul->getOperand(0);
- SDValue ExtendRHS = Mul->getOperand(1);
- assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
- ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
- "expected widening mul");
- assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
- "expected mul to use the same extend for both operands");
-
- SDValue ExtendInLHS = ExtendLHS->getOperand(0);
- SDValue ExtendInRHS = ExtendRHS->getOperand(0);
- bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
-
- if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
- if (IsSigned) {
- // i32x4.dot_i16x8_s
- SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
- ExtendInLHS, ExtendInRHS);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
- }
- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
+ SDValue Input = N->getOperand(2);
+ if (Input->getOpcode() == ISD::MUL) {
+ SDValue ExtendLHS = Input->getOperand(0);
+ SDValue ExtendRHS = Input->getOperand(1);
+ assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
+ ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
+ "expected widening mul or add");
+ assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
+ "expected binop to use the same extend for both operands");
+
+ SDValue ExtendInLHS = ExtendLHS->getOperand(0);
+ SDValue ExtendInRHS = ExtendRHS->getOperand(0);
+ bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
+ unsigned LowOpc =
+ IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
+ unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
+ : WebAssemblyISD::EXTEND_HIGH_U;
+ SDValue LowLHS;
+ SDValue LowRHS;
+ SDValue HighLHS;
+ SDValue HighRHS;
+
+ auto AssignInputs = [&](MVT VT) {
+ LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS);
+ LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS);
+ HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
+ HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
+ };
- // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
- SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInLHS);
- SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInRHS);
- SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInLHS);
- SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInRHS);
+ if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
+ if (IsSigned) {
+ // i32x4.dot_i16x8_s
+ SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
+ ExtendInLHS, ExtendInRHS);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
+ }
- SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v4i32, LowLHS, LowRHS);
- SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v4i32, HighLHS, HighRHS);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, MulLow, MulHigh);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+ // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
+ MVT VT = MVT::v4i32;
+ AssignInputs(VT);
+ SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
+ SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
+ return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
+ } else {
+ assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
+ "expected v16i8 input types");
+ AssignInputs(MVT::v8i16);
+ // Lower to a wider tree, using twice the operations compared to above.
+ if (IsSigned) {
+ // Use two dots
+ SDValue DotLHS =
+ DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
+ SDValue DotRHS =
+ DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+ }
+
+ SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
+ SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
+
+ SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
+ MVT::v4i32, MulLow);
+ SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
+ MVT::v4i32, MulHigh);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+ }
} else {
- assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
- "expected v16i8 input types");
- // Lower to a wider tree, using twice the operations compared to above.
- if (IsSigned) {
- // Use two dots
- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_S;
- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_S;
- SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
- SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
- SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
- SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
- SDValue DotLHS =
- DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
- SDValue DotRHS =
- DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
+ // Accumulate the input using extadd_pairwise.
+ assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
+ bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
+ unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
+ : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
+ SDValue ExtendIn = Input->getOperand(0);
+ if (ExtendIn->getValueType(0) == MVT::v8i16) {
+ SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}
- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
- SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
- SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
- SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
- SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
-
- SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
- SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
-
- SDValue AddLow =
- DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, MVT::v4i32, MulLow);
- SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
- MVT::v4i32, MulHigh);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
+ assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
+ "expected v16i8 input types");
+ SDValue Add =
+ DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
+ DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}
}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 3c26b453c4482..d8948ad2df037 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1454,12 +1454,13 @@ def : Pat<(t1.vt (bitconvert (t2.vt V128:$v))), (t1.vt V128:$v)>;
// Extended pairwise addition
def extadd_pairwise_u : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_U", extend_t>;
+def extadd_pairwise_s : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_S", extend_t>;
-defm "" : SIMDConvert<I16x8, I8x16, int_wasm_extadd_pairwise_signed,
+defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_s,
"extadd_pairwise_i8x16_s", 0x7c>;
defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_u,
"extadd_pairwise_i8x16_u", 0x7d>;
-defm "" : SIMDConvert<I32x4, I16x8, int_wasm_extadd_pairwise_signed,
+defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_s,
"extadd_pairwise_i16x8_s", 0x7e>;
defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_u,
"extadd_pairwise_i16x8_u", 0x7f>;
@@ -1468,6 +1469,10 @@ def : Pat<(v4i32 (int_wasm_extadd_pairwise_unsigned (v8i16 V128:$in))),
(extadd_pairwise_u_I32x4 V128:$in)>;
def : Pat<(v8i16 (int_wasm_extadd_pairwise_unsigned (v16i8 V128:$in))),
(extadd_pairwise_u_I16x8 V128:$in)>;
+def : Pat<(v4i32 (int_wasm_extadd_pairwise_signed (v8i16 V128:$in))),
+ (extadd_pairwise_s_I32x4 V128:$in)>;
+def : Pat<(v8i16 (int_wasm_extadd_pairwise_signed (v16i8 V128:$in))),
+ (extadd_pairwise_s_I16x8 V128:$in)>;
// f64x2 <-> f32x4 conversions
def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 0eefd3e2b3500..92a9812df2127 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -316,7 +316,13 @@ InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
if (CostKind != TTI::TCK_RecipThroughput)
return Invalid;
- InstructionCost Cost(TTI::TCC_Basic);
+ if (Opcode != Instruction::Add)
+ return Invalid;
+
+ EVT AccumEVT = EVT::getEVT(AccumType);
+ // TODO: Add i64 accumulator.
+ if (AccumEVT != MVT::i32)
+ return Invalid;
// Possible options:
// - i16x8.extadd_pairwise_i8x16_sx
@@ -324,23 +330,26 @@ InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
// - i32x4.dot_i16x8_s
// Only try to support dot, for now.
- if (Opcode != Instruction::Add)
+ EVT InputEVT = EVT::getEVT(InputTypeA);
+ if (!((InputEVT == MVT::i16 && VF.getFixedValue() == 8) ||
+ (InputEVT == MVT::i8 && VF.getFixedValue() == 16))) {
return Invalid;
+ }
- if (!BinOp || *BinOp != Instruction::Mul)
+ if (OpAExtend == TTI::PR_None)
return Invalid;
- if (InputTypeA != InputTypeB)
- return Invalid;
+ InstructionCost Cost(TTI::TCC_Basic);
+ if (!BinOp)
+ return Cost;
if (OpAExtend != OpBExtend)
return Invalid;
- EVT InputEVT = EVT::getEVT(InputTypeA);
- EVT AccumEVT = EVT::getEVT(AccumType);
+ if (*BinOp != Instruction::Mul)
+ return Invalid;
- // TODO: Add i64 accumulator.
- if (AccumEVT != MVT::i32)
+ if (InputTypeA != InputTypeB)
return Invalid;
// Signed inputs can lower to dot
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code looks good.
One question on the context and testing: the TTI change causes the loop vectorizer to vectorize where it wouldn't before, and I assume that the default (CHECK) cases cover that (i.e. without this change we wouldn't get the intrinsics in the IR). What's the purpose of the MAX-BANDWIDTH case? How would real code end up exercising this case?
|
The TTI change only affects whether the partial.reduce intrinsic is used. MAX-BANDWIDTH is when we're forcing the vectorizer to choose the vectorization factor based on the memory (narrowest) type instead of the widest type, and, AFAICT, these partial reductions will not happen unless we're optimising for bandwidth. The test uses a vectorizer option, but the same behaviour is also available as a TTI hook, and I'm exploring whether we could/should enable it. |
We currently only support partial.reduce.add when it is performing a multiply-accumulate. This change adds support for any partial reduction whose input is extended, where we can take advantage of extadd_pairwise.