Skip to content

Commit 586c0ad

Browse files
authored
[WebAssembly] Support partial-reduce accumulator (#158060)
We currently only support partial.reduce.add when it performs a multiply-accumulate. This change adds support for any partial reduction whose input is being extended, so that we can take advantage of extadd_pairwise.
1 parent 13daa1e commit 586c0ad

File tree

5 files changed

+736
-90
lines changed

5 files changed

+736
-90
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISD.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ HANDLE_NODETYPE(BR_IF)
2828
HANDLE_NODETYPE(BR_TABLE)
2929
HANDLE_NODETYPE(DOT)
3030
HANDLE_NODETYPE(EXT_ADD_PAIRWISE_U)
31+
HANDLE_NODETYPE(EXT_ADD_PAIRWISE_S)
3132
HANDLE_NODETYPE(SHUFFLE)
3233
HANDLE_NODETYPE(SWIZZLE)
3334
HANDLE_NODETYPE(VEC_SHL)

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 101 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -422,24 +422,30 @@ bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
422422
return true;
423423

424424
EVT VT = EVT::getEVT(I->getType());
425+
if (VT.getSizeInBits() > 128)
426+
return true;
427+
425428
auto Op1 = I->getOperand(1);
426429

427430
if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
428-
if (InstructionOpcodeToISD(InputInst->getOpcode()) != ISD::MUL)
429-
return true;
430-
431-
if (isa<Instruction>(InputInst->getOperand(0)) &&
432-
isa<Instruction>(InputInst->getOperand(1))) {
433-
// dot only supports signed inputs but also support lowering unsigned.
434-
if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
435-
cast<Instruction>(InputInst->getOperand(1))->getOpcode())
436-
return true;
437-
438-
EVT Op1VT = EVT::getEVT(Op1->getType());
439-
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
440-
((VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()) ||
441-
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
442-
return false;
431+
unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
432+
if (Opcode == ISD::MUL) {
433+
if (isa<Instruction>(InputInst->getOperand(0)) &&
434+
isa<Instruction>(InputInst->getOperand(1))) {
435+
// dot only supports signed inputs but also support lowering unsigned.
436+
if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
437+
cast<Instruction>(InputInst->getOperand(1))->getOpcode())
438+
return true;
439+
440+
EVT Op1VT = EVT::getEVT(Op1->getType());
441+
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
442+
((VT.getVectorElementCount() * 2 ==
443+
Op1VT.getVectorElementCount()) ||
444+
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
445+
return false;
446+
}
447+
} else if (ISD::isExtOpcode(Opcode)) {
448+
return false;
443449
}
444450
}
445451
return true;
@@ -2117,77 +2123,93 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
21172123

21182124
assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
21192125
SDLoc DL(N);
2120-
SDValue Mul = N->getOperand(2);
2121-
assert(Mul->getOpcode() == ISD::MUL && "expected mul input");
2122-
2123-
SDValue ExtendLHS = Mul->getOperand(0);
2124-
SDValue ExtendRHS = Mul->getOperand(1);
2125-
assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
2126-
ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
2127-
"expected widening mul");
2128-
assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
2129-
"expected mul to use the same extend for both operands");
2130-
2131-
SDValue ExtendInLHS = ExtendLHS->getOperand(0);
2132-
SDValue ExtendInRHS = ExtendRHS->getOperand(0);
2133-
bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
2134-
2135-
if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
2136-
if (IsSigned) {
2137-
// i32x4.dot_i16x8_s
2138-
SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
2139-
ExtendInLHS, ExtendInRHS);
2140-
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
2141-
}
21422126

2143-
unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
2144-
unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
2127+
SDValue Input = N->getOperand(2);
2128+
if (Input->getOpcode() == ISD::MUL) {
2129+
SDValue ExtendLHS = Input->getOperand(0);
2130+
SDValue ExtendRHS = Input->getOperand(1);
2131+
assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
2132+
ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
2133+
"expected widening mul or add");
2134+
assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
2135+
"expected binop to use the same extend for both operands");
2136+
2137+
SDValue ExtendInLHS = ExtendLHS->getOperand(0);
2138+
SDValue ExtendInRHS = ExtendRHS->getOperand(0);
2139+
bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
2140+
unsigned LowOpc =
2141+
IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
2142+
unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
2143+
: WebAssemblyISD::EXTEND_HIGH_U;
2144+
SDValue LowLHS;
2145+
SDValue LowRHS;
2146+
SDValue HighLHS;
2147+
SDValue HighRHS;
2148+
2149+
auto AssignInputs = [&](MVT VT) {
2150+
LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS);
2151+
LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS);
2152+
HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
2153+
HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
2154+
};
21452155

2146-
// (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
2147-
SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInLHS);
2148-
SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInRHS);
2149-
SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInLHS);
2150-
SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInRHS);
2156+
if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
2157+
if (IsSigned) {
2158+
// i32x4.dot_i16x8_s
2159+
SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
2160+
ExtendInLHS, ExtendInRHS);
2161+
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
2162+
}
21512163

2152-
SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v4i32, LowLHS, LowRHS);
2153-
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v4i32, HighLHS, HighRHS);
2154-
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, MulLow, MulHigh);
2155-
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
2164+
// (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
2165+
MVT VT = MVT::v4i32;
2166+
AssignInputs(VT);
2167+
SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
2168+
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
2169+
SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
2170+
return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
2171+
} else {
2172+
assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
2173+
"expected v16i8 input types");
2174+
AssignInputs(MVT::v8i16);
2175+
// Lower to a wider tree, using twice the operations compared to above.
2176+
if (IsSigned) {
2177+
// Use two dots
2178+
SDValue DotLHS =
2179+
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
2180+
SDValue DotRHS =
2181+
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
2182+
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
2183+
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
2184+
}
2185+
2186+
SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
2187+
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
2188+
2189+
SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2190+
MVT::v4i32, MulLow);
2191+
SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2192+
MVT::v4i32, MulHigh);
2193+
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
2194+
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
2195+
}
21562196
} else {
2157-
assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
2158-
"expected v16i8 input types");
2159-
// Lower to a wider tree, using twice the operations compared to above.
2160-
if (IsSigned) {
2161-
// Use two dots
2162-
unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_S;
2163-
unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_S;
2164-
SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
2165-
SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
2166-
SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
2167-
SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
2168-
SDValue DotLHS =
2169-
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
2170-
SDValue DotRHS =
2171-
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
2172-
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
2197+
// Accumulate the input using extadd_pairwise.
2198+
assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
2199+
bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
2200+
unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
2201+
: WebAssemblyISD::EXT_ADD_PAIRWISE_U;
2202+
SDValue ExtendIn = Input->getOperand(0);
2203+
if (ExtendIn->getValueType(0) == MVT::v8i16) {
2204+
SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
21732205
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
21742206
}
21752207

2176-
unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
2177-
unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
2178-
SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
2179-
SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
2180-
SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
2181-
SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
2182-
2183-
SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
2184-
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
2185-
2186-
SDValue AddLow =
2187-
DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, MVT::v4i32, MulLow);
2188-
SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2189-
MVT::v4i32, MulHigh);
2190-
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
2208+
assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
2209+
"expected v16i8 input types");
2210+
SDValue Add =
2211+
DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
2212+
DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
21912213
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
21922214
}
21932215
}

llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,12 +1454,13 @@ def : Pat<(t1.vt (bitconvert (t2.vt V128:$v))), (t1.vt V128:$v)>;
14541454

14551455
// Extended pairwise addition
14561456
def extadd_pairwise_u : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_U", extend_t>;
1457+
def extadd_pairwise_s : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_S", extend_t>;
14571458

1458-
defm "" : SIMDConvert<I16x8, I8x16, int_wasm_extadd_pairwise_signed,
1459+
defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_s,
14591460
"extadd_pairwise_i8x16_s", 0x7c>;
14601461
defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_u,
14611462
"extadd_pairwise_i8x16_u", 0x7d>;
1462-
defm "" : SIMDConvert<I32x4, I16x8, int_wasm_extadd_pairwise_signed,
1463+
defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_s,
14631464
"extadd_pairwise_i16x8_s", 0x7e>;
14641465
defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_u,
14651466
"extadd_pairwise_i16x8_u", 0x7f>;
@@ -1468,6 +1469,10 @@ def : Pat<(v4i32 (int_wasm_extadd_pairwise_unsigned (v8i16 V128:$in))),
14681469
(extadd_pairwise_u_I32x4 V128:$in)>;
14691470
def : Pat<(v8i16 (int_wasm_extadd_pairwise_unsigned (v16i8 V128:$in))),
14701471
(extadd_pairwise_u_I16x8 V128:$in)>;
1472+
def : Pat<(v4i32 (int_wasm_extadd_pairwise_signed (v8i16 V128:$in))),
1473+
(extadd_pairwise_s_I32x4 V128:$in)>;
1474+
def : Pat<(v8i16 (int_wasm_extadd_pairwise_signed (v16i8 V128:$in))),
1475+
(extadd_pairwise_s_I16x8 V128:$in)>;
14711476

14721477
// f64x2 <-> f32x4 conversions
14731478
def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;

llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -316,31 +316,40 @@ InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
316316
if (CostKind != TTI::TCK_RecipThroughput)
317317
return Invalid;
318318

319-
InstructionCost Cost(TTI::TCC_Basic);
319+
if (Opcode != Instruction::Add)
320+
return Invalid;
321+
322+
EVT AccumEVT = EVT::getEVT(AccumType);
323+
// TODO: Add i64 accumulator.
324+
if (AccumEVT != MVT::i32)
325+
return Invalid;
320326

321327
// Possible options:
322328
// - i16x8.extadd_pairwise_i8x16_sx
323329
// - i32x4.extadd_pairwise_i16x8_sx
324330
// - i32x4.dot_i16x8_s
325331
// Only try to support dot, for now.
326332

327-
if (Opcode != Instruction::Add)
333+
EVT InputEVT = EVT::getEVT(InputTypeA);
334+
if (!((InputEVT == MVT::i16 && VF.getFixedValue() == 8) ||
335+
(InputEVT == MVT::i8 && VF.getFixedValue() == 16))) {
328336
return Invalid;
337+
}
329338

330-
if (!BinOp || *BinOp != Instruction::Mul)
339+
if (OpAExtend == TTI::PR_None)
331340
return Invalid;
332341

333-
if (InputTypeA != InputTypeB)
334-
return Invalid;
342+
InstructionCost Cost(TTI::TCC_Basic);
343+
if (!BinOp)
344+
return Cost;
335345

336346
if (OpAExtend != OpBExtend)
337347
return Invalid;
338348

339-
EVT InputEVT = EVT::getEVT(InputTypeA);
340-
EVT AccumEVT = EVT::getEVT(AccumType);
349+
if (*BinOp != Instruction::Mul)
350+
return Invalid;
341351

342-
// TODO: Add i64 accumulator.
343-
if (AccumEVT != MVT::i32)
352+
if (InputTypeA != InputTypeB)
344353
return Invalid;
345354

346355
// Signed inputs can lower to dot

0 commit comments

Comments
 (0)