@@ -422,24 +422,30 @@ bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
422422 return true ;
423423
424424 EVT VT = EVT::getEVT (I->getType ());
425+ if (VT.getSizeInBits () > 128 )
426+ return true ;
427+
425428 auto Op1 = I->getOperand (1 );
426429
427430 if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
428- if (InstructionOpcodeToISD (InputInst->getOpcode ()) != ISD::MUL)
429- return true ;
430-
431- if (isa<Instruction>(InputInst->getOperand (0 )) &&
432- isa<Instruction>(InputInst->getOperand (1 ))) {
433- // dot only supports signed inputs but also support lowering unsigned.
434- if (cast<Instruction>(InputInst->getOperand (0 ))->getOpcode () !=
435- cast<Instruction>(InputInst->getOperand (1 ))->getOpcode ())
436- return true ;
437-
438- EVT Op1VT = EVT::getEVT (Op1->getType ());
439- if (Op1VT.getVectorElementType () == VT.getVectorElementType () &&
440- ((VT.getVectorElementCount () * 2 == Op1VT.getVectorElementCount ()) ||
441- (VT.getVectorElementCount () * 4 == Op1VT.getVectorElementCount ())))
442- return false ;
431+ unsigned Opcode = InstructionOpcodeToISD (InputInst->getOpcode ());
432+ if (Opcode == ISD::MUL) {
433+ if (isa<Instruction>(InputInst->getOperand (0 )) &&
434+ isa<Instruction>(InputInst->getOperand (1 ))) {
435+ // dot only supports signed inputs but also support lowering unsigned.
436+ if (cast<Instruction>(InputInst->getOperand (0 ))->getOpcode () !=
437+ cast<Instruction>(InputInst->getOperand (1 ))->getOpcode ())
438+ return true ;
439+
440+ EVT Op1VT = EVT::getEVT (Op1->getType ());
441+ if (Op1VT.getVectorElementType () == VT.getVectorElementType () &&
442+ ((VT.getVectorElementCount () * 2 ==
443+ Op1VT.getVectorElementCount ()) ||
444+ (VT.getVectorElementCount () * 4 == Op1VT.getVectorElementCount ())))
445+ return false ;
446+ }
447+ } else if (ISD::isExtOpcode (Opcode)) {
448+ return false ;
443449 }
444450 }
445451 return true ;
@@ -2117,77 +2123,93 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
21172123
21182124 assert (N->getValueType (0 ) == MVT::v4i32 && " can only support v4i32" );
21192125 SDLoc DL (N);
2120- SDValue Mul = N->getOperand (2 );
2121- assert (Mul->getOpcode () == ISD::MUL && " expected mul input" );
2122-
2123- SDValue ExtendLHS = Mul->getOperand (0 );
2124- SDValue ExtendRHS = Mul->getOperand (1 );
2125- assert ((ISD::isExtOpcode (ExtendLHS.getOpcode ()) &&
2126- ISD::isExtOpcode (ExtendRHS.getOpcode ())) &&
2127- " expected widening mul" );
2128- assert (ExtendLHS.getOpcode () == ExtendRHS.getOpcode () &&
2129- " expected mul to use the same extend for both operands" );
2130-
2131- SDValue ExtendInLHS = ExtendLHS->getOperand (0 );
2132- SDValue ExtendInRHS = ExtendRHS->getOperand (0 );
2133- bool IsSigned = ExtendLHS->getOpcode () == ISD::SIGN_EXTEND;
2134-
2135- if (ExtendInLHS->getValueType (0 ) == MVT::v8i16) {
2136- if (IsSigned) {
2137- // i32x4.dot_i16x8_s
2138- SDValue Dot = DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32,
2139- ExtendInLHS, ExtendInRHS);
2140- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Dot);
2141- }
21422126
2143- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
2144- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
2127+ SDValue Input = N->getOperand (2 );
2128+ if (Input->getOpcode () == ISD::MUL) {
2129+ SDValue ExtendLHS = Input->getOperand (0 );
2130+ SDValue ExtendRHS = Input->getOperand (1 );
2131+ assert ((ISD::isExtOpcode (ExtendLHS.getOpcode ()) &&
2132+ ISD::isExtOpcode (ExtendRHS.getOpcode ())) &&
2133+ " expected widening mul or add" );
2134+ assert (ExtendLHS.getOpcode () == ExtendRHS.getOpcode () &&
2135+ " expected binop to use the same extend for both operands" );
2136+
2137+ SDValue ExtendInLHS = ExtendLHS->getOperand (0 );
2138+ SDValue ExtendInRHS = ExtendRHS->getOperand (0 );
2139+ bool IsSigned = ExtendLHS->getOpcode () == ISD::SIGN_EXTEND;
2140+ unsigned LowOpc =
2141+ IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
2142+ unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
2143+ : WebAssemblyISD::EXTEND_HIGH_U;
2144+ SDValue LowLHS;
2145+ SDValue LowRHS;
2146+ SDValue HighLHS;
2147+ SDValue HighRHS;
2148+
2149+ auto AssignInputs = [&](MVT VT) {
2150+ LowLHS = DAG.getNode (LowOpc, DL, VT, ExtendInLHS);
2151+ LowRHS = DAG.getNode (LowOpc, DL, VT, ExtendInRHS);
2152+ HighLHS = DAG.getNode (HighOpc, DL, VT, ExtendInLHS);
2153+ HighRHS = DAG.getNode (HighOpc, DL, VT, ExtendInRHS);
2154+ };
21452155
2146- // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
2147- SDValue LowLHS = DAG.getNode (LowOpc, DL, MVT::v4i32, ExtendInLHS);
2148- SDValue LowRHS = DAG.getNode (LowOpc, DL, MVT::v4i32, ExtendInRHS);
2149- SDValue HighLHS = DAG.getNode (HighOpc, DL, MVT::v4i32, ExtendInLHS);
2150- SDValue HighRHS = DAG.getNode (HighOpc, DL, MVT::v4i32, ExtendInRHS);
2156+ if (ExtendInLHS->getValueType (0 ) == MVT::v8i16) {
2157+ if (IsSigned) {
2158+ // i32x4.dot_i16x8_s
2159+ SDValue Dot = DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32,
2160+ ExtendInLHS, ExtendInRHS);
2161+ return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Dot);
2162+ }
21512163
2152- SDValue MulLow = DAG.getNode (ISD::MUL, DL, MVT::v4i32, LowLHS, LowRHS);
2153- SDValue MulHigh = DAG.getNode (ISD::MUL, DL, MVT::v4i32, HighLHS, HighRHS);
2154- SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, MulLow, MulHigh);
2155- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2164+ // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
2165+ MVT VT = MVT::v4i32;
2166+ AssignInputs (VT);
2167+ SDValue MulLow = DAG.getNode (ISD::MUL, DL, VT, LowLHS, LowRHS);
2168+ SDValue MulHigh = DAG.getNode (ISD::MUL, DL, VT, HighLHS, HighRHS);
2169+ SDValue Add = DAG.getNode (ISD::ADD, DL, VT, MulLow, MulHigh);
2170+ return DAG.getNode (ISD::ADD, DL, VT, N->getOperand (1 ), Add);
2171+ } else {
2172+ assert (ExtendInLHS->getValueType (0 ) == MVT::v16i8 &&
2173+ " expected v16i8 input types" );
2174+ AssignInputs (MVT::v8i16);
2175+ // Lower to a wider tree, using twice the operations compared to above.
2176+ if (IsSigned) {
2177+ // Use two dots
2178+ SDValue DotLHS =
2179+ DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
2180+ SDValue DotRHS =
2181+ DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
2182+ SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
2183+ return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2184+ }
2185+
2186+ SDValue MulLow = DAG.getNode (ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
2187+ SDValue MulHigh = DAG.getNode (ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
2188+
2189+ SDValue AddLow = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2190+ MVT::v4i32, MulLow);
2191+ SDValue AddHigh = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2192+ MVT::v4i32, MulHigh);
2193+ SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
2194+ return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2195+ }
21562196 } else {
2157- assert (ExtendInLHS->getValueType (0 ) == MVT::v16i8 &&
2158- " expected v16i8 input types" );
2159- // Lower to a wider tree, using twice the operations compared to above.
2160- if (IsSigned) {
2161- // Use two dots
2162- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_S;
2163- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_S;
2164- SDValue LowLHS = DAG.getNode (LowOpc, DL, MVT::v8i16, ExtendInLHS);
2165- SDValue LowRHS = DAG.getNode (LowOpc, DL, MVT::v8i16, ExtendInRHS);
2166- SDValue HighLHS = DAG.getNode (HighOpc, DL, MVT::v8i16, ExtendInLHS);
2167- SDValue HighRHS = DAG.getNode (HighOpc, DL, MVT::v8i16, ExtendInRHS);
2168- SDValue DotLHS =
2169- DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
2170- SDValue DotRHS =
2171- DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
2172- SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
2197+ // Accumulate the input using extadd_pairwise.
2198+ assert (ISD::isExtOpcode (Input.getOpcode ()) && " expected extend" );
2199+ bool IsSigned = Input->getOpcode () == ISD::SIGN_EXTEND;
2200+ unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
2201+ : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
2202+ SDValue ExtendIn = Input->getOperand (0 );
2203+ if (ExtendIn->getValueType (0 ) == MVT::v8i16) {
2204+ SDValue Add = DAG.getNode (PairwiseOpc, DL, MVT::v4i32, ExtendIn);
21732205 return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
21742206 }
21752207
2176- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
2177- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
2178- SDValue LowLHS = DAG.getNode (LowOpc, DL, MVT::v8i16, ExtendInLHS);
2179- SDValue LowRHS = DAG.getNode (LowOpc, DL, MVT::v8i16, ExtendInRHS);
2180- SDValue HighLHS = DAG.getNode (HighOpc, DL, MVT::v8i16, ExtendInLHS);
2181- SDValue HighRHS = DAG.getNode (HighOpc, DL, MVT::v8i16, ExtendInRHS);
2182-
2183- SDValue MulLow = DAG.getNode (ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
2184- SDValue MulHigh = DAG.getNode (ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
2185-
2186- SDValue AddLow =
2187- DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, MVT::v4i32, MulLow);
2188- SDValue AddHigh = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2189- MVT::v4i32, MulHigh);
2190- SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
2208+ assert (ExtendIn->getValueType (0 ) == MVT::v16i8 &&
2209+ " expected v16i8 input types" );
2210+ SDValue Add =
2211+ DAG.getNode (PairwiseOpc, DL, MVT::v4i32,
2212+ DAG.getNode (PairwiseOpc, DL, MVT::v8i16, ExtendIn));
21912213 return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
21922214 }
21932215}
0 commit comments