@@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
186
186
// SIMD-specific configuration
187
187
if (Subtarget->hasSIMD128 ()) {
188
188
189
- // Combine partial.reduce.add before legalization gets confused.
190
189
setTargetDAGCombine (ISD::INTRINSIC_WO_CHAIN);
191
190
192
191
// Combine wide-vector muls, with extend inputs, to extmul_half.
@@ -317,6 +316,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
317
316
setOperationAction (ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
318
317
setOperationAction (ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
319
318
}
319
+
320
+ // Partial MLA reductions.
321
+ for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) {
322
+ setPartialReduceMLAAction (Op, MVT::v4i32, MVT::v16i8, Legal);
323
+ setPartialReduceMLAAction (Op, MVT::v4i32, MVT::v8i16, Legal);
324
+ }
320
325
}
321
326
322
327
// As a special case, these operators use the type to mean the type to
@@ -416,41 +421,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
416
421
return TargetLowering::getPointerMemTy (DL, AS);
417
422
}
418
423
419
- bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic (
420
- const IntrinsicInst *I) const {
421
- if (I->getIntrinsicID () != Intrinsic::vector_partial_reduce_add)
422
- return true ;
423
-
424
- EVT VT = EVT::getEVT (I->getType ());
425
- if (VT.getSizeInBits () > 128 )
426
- return true ;
427
-
428
- auto Op1 = I->getOperand (1 );
429
-
430
- if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
431
- unsigned Opcode = InstructionOpcodeToISD (InputInst->getOpcode ());
432
- if (Opcode == ISD::MUL) {
433
- if (isa<Instruction>(InputInst->getOperand (0 )) &&
434
- isa<Instruction>(InputInst->getOperand (1 ))) {
435
- // dot only supports signed inputs but also support lowering unsigned.
436
- if (cast<Instruction>(InputInst->getOperand (0 ))->getOpcode () !=
437
- cast<Instruction>(InputInst->getOperand (1 ))->getOpcode ())
438
- return true ;
439
-
440
- EVT Op1VT = EVT::getEVT (Op1->getType ());
441
- if (Op1VT.getVectorElementType () == VT.getVectorElementType () &&
442
- ((VT.getVectorElementCount () * 2 ==
443
- Op1VT.getVectorElementCount ()) ||
444
- (VT.getVectorElementCount () * 4 == Op1VT.getVectorElementCount ())))
445
- return false ;
446
- }
447
- } else if (ISD::isExtOpcode (Opcode)) {
448
- return false ;
449
- }
450
- }
451
- return true ;
452
- }
453
-
454
424
TargetLowering::AtomicExpansionKind
455
425
WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR (AtomicRMWInst *AI) const {
456
426
// We have wasm instructions for these
@@ -2113,106 +2083,6 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
2113
2083
MachinePointerInfo (SV));
2114
2084
}
2115
2085
2116
- // Try to lower partial.reduce.add to a dot or fallback to a sequence with
2117
- // extmul and adds.
2118
- SDValue performLowerPartialReduction (SDNode *N, SelectionDAG &DAG) {
2119
- assert (N->getOpcode () == ISD::INTRINSIC_WO_CHAIN);
2120
- if (N->getConstantOperandVal (0 ) != Intrinsic::vector_partial_reduce_add)
2121
- return SDValue ();
2122
-
2123
- assert (N->getValueType (0 ) == MVT::v4i32 && " can only support v4i32" );
2124
- SDLoc DL (N);
2125
-
2126
- SDValue Input = N->getOperand (2 );
2127
- if (Input->getOpcode () == ISD::MUL) {
2128
- SDValue ExtendLHS = Input->getOperand (0 );
2129
- SDValue ExtendRHS = Input->getOperand (1 );
2130
- assert ((ISD::isExtOpcode (ExtendLHS.getOpcode ()) &&
2131
- ISD::isExtOpcode (ExtendRHS.getOpcode ())) &&
2132
- " expected widening mul or add" );
2133
- assert (ExtendLHS.getOpcode () == ExtendRHS.getOpcode () &&
2134
- " expected binop to use the same extend for both operands" );
2135
-
2136
- SDValue ExtendInLHS = ExtendLHS->getOperand (0 );
2137
- SDValue ExtendInRHS = ExtendRHS->getOperand (0 );
2138
- bool IsSigned = ExtendLHS->getOpcode () == ISD::SIGN_EXTEND;
2139
- unsigned LowOpc =
2140
- IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
2141
- unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
2142
- : WebAssemblyISD::EXTEND_HIGH_U;
2143
- SDValue LowLHS;
2144
- SDValue LowRHS;
2145
- SDValue HighLHS;
2146
- SDValue HighRHS;
2147
-
2148
- auto AssignInputs = [&](MVT VT) {
2149
- LowLHS = DAG.getNode (LowOpc, DL, VT, ExtendInLHS);
2150
- LowRHS = DAG.getNode (LowOpc, DL, VT, ExtendInRHS);
2151
- HighLHS = DAG.getNode (HighOpc, DL, VT, ExtendInLHS);
2152
- HighRHS = DAG.getNode (HighOpc, DL, VT, ExtendInRHS);
2153
- };
2154
-
2155
- if (ExtendInLHS->getValueType (0 ) == MVT::v8i16) {
2156
- if (IsSigned) {
2157
- // i32x4.dot_i16x8_s
2158
- SDValue Dot = DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32,
2159
- ExtendInLHS, ExtendInRHS);
2160
- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Dot);
2161
- }
2162
-
2163
- // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
2164
- MVT VT = MVT::v4i32;
2165
- AssignInputs (VT);
2166
- SDValue MulLow = DAG.getNode (ISD::MUL, DL, VT, LowLHS, LowRHS);
2167
- SDValue MulHigh = DAG.getNode (ISD::MUL, DL, VT, HighLHS, HighRHS);
2168
- SDValue Add = DAG.getNode (ISD::ADD, DL, VT, MulLow, MulHigh);
2169
- return DAG.getNode (ISD::ADD, DL, VT, N->getOperand (1 ), Add);
2170
- } else {
2171
- assert (ExtendInLHS->getValueType (0 ) == MVT::v16i8 &&
2172
- " expected v16i8 input types" );
2173
- AssignInputs (MVT::v8i16);
2174
- // Lower to a wider tree, using twice the operations compared to above.
2175
- if (IsSigned) {
2176
- // Use two dots
2177
- SDValue DotLHS =
2178
- DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
2179
- SDValue DotRHS =
2180
- DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
2181
- SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
2182
- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2183
- }
2184
-
2185
- SDValue MulLow = DAG.getNode (ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
2186
- SDValue MulHigh = DAG.getNode (ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
2187
-
2188
- SDValue AddLow = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2189
- MVT::v4i32, MulLow);
2190
- SDValue AddHigh = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2191
- MVT::v4i32, MulHigh);
2192
- SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
2193
- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2194
- }
2195
- } else {
2196
- // Accumulate the input using extadd_pairwise.
2197
- assert (ISD::isExtOpcode (Input.getOpcode ()) && " expected extend" );
2198
- bool IsSigned = Input->getOpcode () == ISD::SIGN_EXTEND;
2199
- unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
2200
- : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
2201
- SDValue ExtendIn = Input->getOperand (0 );
2202
- if (ExtendIn->getValueType (0 ) == MVT::v8i16) {
2203
- SDValue Add = DAG.getNode (PairwiseOpc, DL, MVT::v4i32, ExtendIn);
2204
- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2205
- }
2206
-
2207
- assert (ExtendIn->getValueType (0 ) == MVT::v16i8 &&
2208
- " expected v16i8 input types" );
2209
- SDValue Add =
2210
- DAG.getNode (PairwiseOpc, DL, MVT::v4i32,
2211
- DAG.getNode (PairwiseOpc, DL, MVT::v8i16, ExtendIn));
2212
- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2213
- }
2214
- }
2215
-
2216
2086
SDValue WebAssemblyTargetLowering::LowerIntrinsic (SDValue Op,
2217
2087
SelectionDAG &DAG) const {
2218
2088
MachineFunction &MF = DAG.getMachineFunction ();
@@ -3683,11 +3553,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
3683
3553
return performVectorTruncZeroCombine (N, DCI);
3684
3554
case ISD::TRUNCATE:
3685
3555
return performTruncateCombine (N, DCI);
3686
- case ISD::INTRINSIC_WO_CHAIN: {
3687
- if (auto AnyAllCombine = performAnyAllCombine (N, DCI.DAG ))
3688
- return AnyAllCombine;
3689
- return performLowerPartialReduction (N, DCI.DAG );
3690
- }
3556
+ case ISD::INTRINSIC_WO_CHAIN:
3557
+ return performAnyAllCombine (N, DCI.DAG );
3691
3558
case ISD::MUL:
3692
3559
return performMulCombine (N, DCI);
3693
3560
}
0 commit comments