Skip to content

Commit 156e9b4

Browse files
authored
[WebAssembly] Use partial_reduce_mla ISD nodes (#161184)
Addressing issue #160847. Move away from combining the intrinsic call and instead lower the ISD nodes, using tablegen for pattern matching.
1 parent ea452c0 commit 156e9b4

File tree

4 files changed

+61
-153
lines changed

4 files changed

+61
-153
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 8 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
186186
// SIMD-specific configuration
187187
if (Subtarget->hasSIMD128()) {
188188

189-
// Combine partial.reduce.add before legalization gets confused.
190189
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
191190

192191
// Combine wide-vector muls, with extend inputs, to extmul_half.
@@ -317,6 +316,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
317316
setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
318317
setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
319318
}
319+
320+
// Partial MLA reductions.
321+
for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) {
322+
setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal);
323+
setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v8i16, Legal);
324+
}
320325
}
321326

322327
// As a special case, these operators use the type to mean the type to
@@ -416,41 +421,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
416421
return TargetLowering::getPointerMemTy(DL, AS);
417422
}
418423

419-
bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
420-
const IntrinsicInst *I) const {
421-
if (I->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
422-
return true;
423-
424-
EVT VT = EVT::getEVT(I->getType());
425-
if (VT.getSizeInBits() > 128)
426-
return true;
427-
428-
auto Op1 = I->getOperand(1);
429-
430-
if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
431-
unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
432-
if (Opcode == ISD::MUL) {
433-
if (isa<Instruction>(InputInst->getOperand(0)) &&
434-
isa<Instruction>(InputInst->getOperand(1))) {
435-
// dot only supports signed inputs but also support lowering unsigned.
436-
if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
437-
cast<Instruction>(InputInst->getOperand(1))->getOpcode())
438-
return true;
439-
440-
EVT Op1VT = EVT::getEVT(Op1->getType());
441-
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
442-
((VT.getVectorElementCount() * 2 ==
443-
Op1VT.getVectorElementCount()) ||
444-
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
445-
return false;
446-
}
447-
} else if (ISD::isExtOpcode(Opcode)) {
448-
return false;
449-
}
450-
}
451-
return true;
452-
}
453-
454424
TargetLowering::AtomicExpansionKind
455425
WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
456426
// We have wasm instructions for these
@@ -2113,106 +2083,6 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
21132083
MachinePointerInfo(SV));
21142084
}
21152085

2116-
// Try to lower partial.reduce.add to a dot or fallback to a sequence with
2117-
// extmul and adds.
2118-
SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
2119-
assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
2120-
if (N->getConstantOperandVal(0) != Intrinsic::vector_partial_reduce_add)
2121-
return SDValue();
2122-
2123-
assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
2124-
SDLoc DL(N);
2125-
2126-
SDValue Input = N->getOperand(2);
2127-
if (Input->getOpcode() == ISD::MUL) {
2128-
SDValue ExtendLHS = Input->getOperand(0);
2129-
SDValue ExtendRHS = Input->getOperand(1);
2130-
assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
2131-
ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
2132-
"expected widening mul or add");
2133-
assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
2134-
"expected binop to use the same extend for both operands");
2135-
2136-
SDValue ExtendInLHS = ExtendLHS->getOperand(0);
2137-
SDValue ExtendInRHS = ExtendRHS->getOperand(0);
2138-
bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
2139-
unsigned LowOpc =
2140-
IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
2141-
unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
2142-
: WebAssemblyISD::EXTEND_HIGH_U;
2143-
SDValue LowLHS;
2144-
SDValue LowRHS;
2145-
SDValue HighLHS;
2146-
SDValue HighRHS;
2147-
2148-
auto AssignInputs = [&](MVT VT) {
2149-
LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS);
2150-
LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS);
2151-
HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
2152-
HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
2153-
};
2154-
2155-
if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
2156-
if (IsSigned) {
2157-
// i32x4.dot_i16x8_s
2158-
SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
2159-
ExtendInLHS, ExtendInRHS);
2160-
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
2161-
}
2162-
2163-
// (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
2164-
MVT VT = MVT::v4i32;
2165-
AssignInputs(VT);
2166-
SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
2167-
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
2168-
SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
2169-
return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
2170-
} else {
2171-
assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
2172-
"expected v16i8 input types");
2173-
AssignInputs(MVT::v8i16);
2174-
// Lower to a wider tree, using twice the operations compared to above.
2175-
if (IsSigned) {
2176-
// Use two dots
2177-
SDValue DotLHS =
2178-
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
2179-
SDValue DotRHS =
2180-
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
2181-
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
2182-
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
2183-
}
2184-
2185-
SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
2186-
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
2187-
2188-
SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2189-
MVT::v4i32, MulLow);
2190-
SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2191-
MVT::v4i32, MulHigh);
2192-
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
2193-
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
2194-
}
2195-
} else {
2196-
// Accumulate the input using extadd_pairwise.
2197-
assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
2198-
bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
2199-
unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
2200-
: WebAssemblyISD::EXT_ADD_PAIRWISE_U;
2201-
SDValue ExtendIn = Input->getOperand(0);
2202-
if (ExtendIn->getValueType(0) == MVT::v8i16) {
2203-
SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
2204-
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
2205-
}
2206-
2207-
assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
2208-
"expected v16i8 input types");
2209-
SDValue Add =
2210-
DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
2211-
DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
2212-
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
2213-
}
2214-
}
2215-
22162086
SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
22172087
SelectionDAG &DAG) const {
22182088
MachineFunction &MF = DAG.getMachineFunction();
@@ -3683,11 +3553,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
36833553
return performVectorTruncZeroCombine(N, DCI);
36843554
case ISD::TRUNCATE:
36853555
return performTruncateCombine(N, DCI);
3686-
case ISD::INTRINSIC_WO_CHAIN: {
3687-
if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG))
3688-
return AnyAllCombine;
3689-
return performLowerPartialReduction(N, DCI.DAG);
3690-
}
3556+
case ISD::INTRINSIC_WO_CHAIN:
3557+
return performAnyAllCombine(N, DCI.DAG);
36913558
case ISD::MUL:
36923559
return performMulCombine(N, DCI);
36933560
}

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ class WebAssemblyTargetLowering final : public TargetLowering {
4545
/// right decision when generating code for different targets.
4646
const WebAssemblySubtarget *Subtarget;
4747

48-
bool
49-
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
5048
AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
5149
bool shouldScalarizeBinop(SDValue VecOp) const override;
5250
FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,
@@ -89,8 +87,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
8987
bool CanLowerReturn(CallingConv::ID CallConv, MachineFunction &MF,
9088
bool isVarArg,
9189
const SmallVectorImpl<ISD::OutputArg> &Outs,
92-
LLVMContext &Context,
93-
const Type *RetTy) const override;
90+
LLVMContext &Context, const Type *RetTy) const override;
9491
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
9592
const SmallVectorImpl<ISD::OutputArg> &Outs,
9693
const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,

llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,51 @@ def : Pat<(v2f64 (extloadv2f32 (i64 I64:$addr))),
15041504
defm Q15MULR_SAT_S :
15051505
SIMDBinary<I16x8, int_wasm_q15mulr_sat_signed, "q15mulr_sat_s", 0x82>;
15061506

1507+
//===----------------------------------------------------------------------===//
1508+
// Partial reductions, using: dot, extmul and extadd_pairwise
1509+
//===----------------------------------------------------------------------===//
1510+
// MLA: v8i16 -> v4i32
1511+
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$lhs),
1512+
(v8i16 V128:$rhs))),
1513+
(ADD_I32x4 (DOT $lhs, $rhs), $acc)>;
1514+
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$lhs),
1515+
(v8i16 V128:$rhs))),
1516+
(ADD_I32x4 (ADD_I32x4 (EXTMUL_LOW_U_I32x4 $lhs, $rhs),
1517+
(EXTMUL_HIGH_U_I32x4 $lhs, $rhs)),
1518+
$acc)>;
1519+
// MLA: v16i8 -> v4i32
1520+
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$lhs),
1521+
(v16i8 V128:$rhs))),
1522+
(ADD_I32x4 (ADD_I32x4 (DOT (extend_low_s_I16x8 $lhs),
1523+
(extend_low_s_I16x8 $rhs)),
1524+
(DOT (extend_high_s_I16x8 $lhs),
1525+
(extend_high_s_I16x8 $rhs))),
1526+
$acc)>;
1527+
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$lhs),
1528+
(v16i8 V128:$rhs))),
1529+
(ADD_I32x4 (ADD_I32x4 (extadd_pairwise_u_I32x4 (EXTMUL_LOW_U_I16x8 $lhs, $rhs)),
1530+
(extadd_pairwise_u_I32x4 (EXTMUL_HIGH_U_I16x8 $lhs, $rhs))),
1531+
$acc)>;
1532+
1533+
// Accumulate: v8i16 -> v4i32
1534+
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$in),
1535+
(I16x8.splat (i32 1)))),
1536+
(ADD_I32x4 (extadd_pairwise_s_I32x4 $in), $acc)>;
1537+
1538+
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$in),
1539+
(I16x8.splat (i32 1)))),
1540+
(ADD_I32x4 (extadd_pairwise_u_I32x4 $in), $acc)>;
1541+
1542+
// Accumulate: v16i8 -> v4i32
1543+
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$in),
1544+
(I8x16.splat (i32 1)))),
1545+
(ADD_I32x4 (extadd_pairwise_s_I32x4 (extadd_pairwise_s_I16x8 $in)),
1546+
$acc)>;
1547+
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$in),
1548+
(I8x16.splat (i32 1)))),
1549+
(ADD_I32x4 (extadd_pairwise_u_I32x4 (extadd_pairwise_u_I16x8 $in)),
1550+
$acc)>;
1551+
15071552
//===----------------------------------------------------------------------===//
15081553
// Relaxed swizzle
15091554
//===----------------------------------------------------------------------===//

llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ define hidden i32 @accumulate_add_u8_u8(ptr noundef readonly %a, ptr noundef re
1919
; MAX-BANDWIDTH: v128.load
2020
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u
2121
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
22-
; MAX-BANDWIDTH: i32x4.add
2322
; MAX-BANDWIDTH: v128.load
2423
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u
2524
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
2625
; MAX-BANDWIDTH: i32x4.add
26+
; MAX-BANDWIDTH: i32x4.add
2727

2828
entry:
2929
%cmp8.not = icmp eq i32 %N, 0
@@ -65,11 +65,11 @@ define hidden i32 @accumulate_add_s8_s8(ptr noundef readonly %a, ptr noundef re
6565
; MAX-BANDWIDTH: v128.load
6666
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
6767
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
68-
; MAX-BANDWIDTH: i32x4.add
6968
; MAX-BANDWIDTH: v128.load
7069
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
7170
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
7271
; MAX-BANDWIDTH: i32x4.add
72+
; MAX-BANDWIDTH: i32x4.add
7373
entry:
7474
%cmp8.not = icmp eq i32 %N, 0
7575
br i1 %cmp8.not, label %for.cond.cleanup, label %for.body
@@ -108,12 +108,11 @@ define hidden i32 @accumulate_add_s8_u8(ptr noundef readonly %a, ptr noundef re
108108

109109
; MAX-BANDWIDTH: loop
110110
; MAX-BANDWIDTH: v128.load
111-
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
112-
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
113-
; MAX-BANDWIDTH: i32x4.add
114-
; MAX-BANDWIDTH: v128.load
115111
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u
116112
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
113+
; MAX-BANDWIDTH: v128.load
114+
; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
115+
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
117116
; MAX-BANDWIDTH: i32x4.add
118117
entry:
119118
%cmp8.not = icmp eq i32 %N, 0
@@ -363,10 +362,10 @@ define hidden i32 @accumulate_add_u16_u16(ptr noundef readonly %a, ptr noundef
363362
; MAX-BANDWIDTH: loop
364363
; MAX-BANDWIDTH: v128.load
365364
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
366-
; MAX-BANDWIDTH: i32x4.add
367365
; MAX-BANDWIDTH: v128.load
368366
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
369367
; MAX-BANDWIDTH: i32x4.add
368+
; MAX-BANDWIDTH: i32x4.add
370369
entry:
371370
%cmp8.not = icmp eq i32 %N, 0
372371
br i1 %cmp8.not, label %for.cond.cleanup, label %for.body
@@ -402,10 +401,10 @@ define hidden i32 @accumulate_add_s16_s16(ptr noundef readonly %a, ptr noundef
402401
; MAX-BANDWIDTH: loop
403402
; MAX-BANDWIDTH: v128.load
404403
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
405-
; MAX-BANDWIDTH: i32x4.add
406404
; MAX-BANDWIDTH: v128.load
407405
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
408406
; MAX-BANDWIDTH: i32x4.add
407+
; MAX-BANDWIDTH: i32x4.add
409408
entry:
410409
%cmp8.not = icmp eq i32 %N, 0
411410
br i1 %cmp8.not, label %for.cond.cleanup, label %for.body

0 commit comments

Comments
 (0)