@@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
186186 // SIMD-specific configuration
187187 if (Subtarget->hasSIMD128 ()) {
188188
189- // Combine partial.reduce.add before legalization gets confused.
190189 setTargetDAGCombine (ISD::INTRINSIC_WO_CHAIN);
191190
192191 // Combine wide-vector muls, with extend inputs, to extmul_half.
@@ -317,6 +316,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
317316 setOperationAction (ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
318317 setOperationAction (ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
319318 }
319+
320+ // Partial MLA reductions.
321+ for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) {
322+ setPartialReduceMLAAction (Op, MVT::v4i32, MVT::v16i8, Legal);
323+ setPartialReduceMLAAction (Op, MVT::v4i32, MVT::v8i16, Legal);
324+ }
320325 }
321326
322327 // As a special case, these operators use the type to mean the type to
@@ -416,41 +421,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
416421 return TargetLowering::getPointerMemTy (DL, AS);
417422}
418423
419- bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic (
420- const IntrinsicInst *I) const {
421- if (I->getIntrinsicID () != Intrinsic::vector_partial_reduce_add)
422- return true ;
423-
424- EVT VT = EVT::getEVT (I->getType ());
425- if (VT.getSizeInBits () > 128 )
426- return true ;
427-
428- auto Op1 = I->getOperand (1 );
429-
430- if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
431- unsigned Opcode = InstructionOpcodeToISD (InputInst->getOpcode ());
432- if (Opcode == ISD::MUL) {
433- if (isa<Instruction>(InputInst->getOperand (0 )) &&
434- isa<Instruction>(InputInst->getOperand (1 ))) {
435- // dot only supports signed inputs but also support lowering unsigned.
436- if (cast<Instruction>(InputInst->getOperand (0 ))->getOpcode () !=
437- cast<Instruction>(InputInst->getOperand (1 ))->getOpcode ())
438- return true ;
439-
440- EVT Op1VT = EVT::getEVT (Op1->getType ());
441- if (Op1VT.getVectorElementType () == VT.getVectorElementType () &&
442- ((VT.getVectorElementCount () * 2 ==
443- Op1VT.getVectorElementCount ()) ||
444- (VT.getVectorElementCount () * 4 == Op1VT.getVectorElementCount ())))
445- return false ;
446- }
447- } else if (ISD::isExtOpcode (Opcode)) {
448- return false ;
449- }
450- }
451- return true ;
452- }
453-
454424TargetLowering::AtomicExpansionKind
455425WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR (AtomicRMWInst *AI) const {
456426 // We have wasm instructions for these
@@ -2113,106 +2083,6 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
21132083 MachinePointerInfo (SV));
21142084}
21152085
2116- // Try to lower partial.reduce.add to a dot or fallback to a sequence with
2117- // extmul and adds.
2118- SDValue performLowerPartialReduction (SDNode *N, SelectionDAG &DAG) {
2119- assert (N->getOpcode () == ISD::INTRINSIC_WO_CHAIN);
2120- if (N->getConstantOperandVal (0 ) != Intrinsic::vector_partial_reduce_add)
2121- return SDValue ();
2122-
2123- assert (N->getValueType (0 ) == MVT::v4i32 && " can only support v4i32" );
2124- SDLoc DL (N);
2125-
2126- SDValue Input = N->getOperand (2 );
2127- if (Input->getOpcode () == ISD::MUL) {
2128- SDValue ExtendLHS = Input->getOperand (0 );
2129- SDValue ExtendRHS = Input->getOperand (1 );
2130- assert ((ISD::isExtOpcode (ExtendLHS.getOpcode ()) &&
2131- ISD::isExtOpcode (ExtendRHS.getOpcode ())) &&
2132- " expected widening mul or add" );
2133- assert (ExtendLHS.getOpcode () == ExtendRHS.getOpcode () &&
2134- " expected binop to use the same extend for both operands" );
2135-
2136- SDValue ExtendInLHS = ExtendLHS->getOperand (0 );
2137- SDValue ExtendInRHS = ExtendRHS->getOperand (0 );
2138- bool IsSigned = ExtendLHS->getOpcode () == ISD::SIGN_EXTEND;
2139- unsigned LowOpc =
2140- IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
2141- unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
2142- : WebAssemblyISD::EXTEND_HIGH_U;
2143- SDValue LowLHS;
2144- SDValue LowRHS;
2145- SDValue HighLHS;
2146- SDValue HighRHS;
2147-
2148- auto AssignInputs = [&](MVT VT) {
2149- LowLHS = DAG.getNode (LowOpc, DL, VT, ExtendInLHS);
2150- LowRHS = DAG.getNode (LowOpc, DL, VT, ExtendInRHS);
2151- HighLHS = DAG.getNode (HighOpc, DL, VT, ExtendInLHS);
2152- HighRHS = DAG.getNode (HighOpc, DL, VT, ExtendInRHS);
2153- };
2154-
2155- if (ExtendInLHS->getValueType (0 ) == MVT::v8i16) {
2156- if (IsSigned) {
2157- // i32x4.dot_i16x8_s
2158- SDValue Dot = DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32,
2159- ExtendInLHS, ExtendInRHS);
2160- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Dot);
2161- }
2162-
2163- // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
2164- MVT VT = MVT::v4i32;
2165- AssignInputs (VT);
2166- SDValue MulLow = DAG.getNode (ISD::MUL, DL, VT, LowLHS, LowRHS);
2167- SDValue MulHigh = DAG.getNode (ISD::MUL, DL, VT, HighLHS, HighRHS);
2168- SDValue Add = DAG.getNode (ISD::ADD, DL, VT, MulLow, MulHigh);
2169- return DAG.getNode (ISD::ADD, DL, VT, N->getOperand (1 ), Add);
2170- } else {
2171- assert (ExtendInLHS->getValueType (0 ) == MVT::v16i8 &&
2172- " expected v16i8 input types" );
2173- AssignInputs (MVT::v8i16);
2174- // Lower to a wider tree, using twice the operations compared to above.
2175- if (IsSigned) {
2176- // Use two dots
2177- SDValue DotLHS =
2178- DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
2179- SDValue DotRHS =
2180- DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
2181- SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
2182- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2183- }
2184-
2185- SDValue MulLow = DAG.getNode (ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
2186- SDValue MulHigh = DAG.getNode (ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
2187-
2188- SDValue AddLow = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2189- MVT::v4i32, MulLow);
2190- SDValue AddHigh = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2191- MVT::v4i32, MulHigh);
2192- SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
2193- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2194- }
2195- } else {
2196- // Accumulate the input using extadd_pairwise.
2197- assert (ISD::isExtOpcode (Input.getOpcode ()) && " expected extend" );
2198- bool IsSigned = Input->getOpcode () == ISD::SIGN_EXTEND;
2199- unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
2200- : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
2201- SDValue ExtendIn = Input->getOperand (0 );
2202- if (ExtendIn->getValueType (0 ) == MVT::v8i16) {
2203- SDValue Add = DAG.getNode (PairwiseOpc, DL, MVT::v4i32, ExtendIn);
2204- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2205- }
2206-
2207- assert (ExtendIn->getValueType (0 ) == MVT::v16i8 &&
2208- " expected v16i8 input types" );
2209- SDValue Add =
2210- DAG.getNode (PairwiseOpc, DL, MVT::v4i32,
2211- DAG.getNode (PairwiseOpc, DL, MVT::v8i16, ExtendIn));
2212- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2213- }
2214- }
2215-
22162086SDValue WebAssemblyTargetLowering::LowerIntrinsic (SDValue Op,
22172087 SelectionDAG &DAG) const {
22182088 MachineFunction &MF = DAG.getMachineFunction ();
@@ -3683,11 +3553,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
36833553 return performVectorTruncZeroCombine (N, DCI);
36843554 case ISD::TRUNCATE:
36853555 return performTruncateCombine (N, DCI);
3686- case ISD::INTRINSIC_WO_CHAIN: {
3687- if (auto AnyAllCombine = performAnyAllCombine (N, DCI.DAG ))
3688- return AnyAllCombine;
3689- return performLowerPartialReduction (N, DCI.DAG );
3690- }
3556+ case ISD::INTRINSIC_WO_CHAIN:
3557+ return performAnyAllCombine (N, DCI.DAG );
36913558 case ISD::MUL:
36923559 return performMulCombine (N, DCI);
36933560 }
0 commit comments