@@ -192,6 +192,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
192192 // Combine wide-vector muls, with extend inputs, to extmul_half.
193193 setTargetDAGCombine (ISD::MUL);
194194
195+ // Combine add with vector shuffle of muls to dots
196+ setTargetDAGCombine (ISD::ADD);
197+
195198 // Combine vector mask reductions into alltrue/anytrue
196199 setTargetDAGCombine (ISD::SETCC);
197200
@@ -3436,6 +3439,52 @@ static SDValue performSETCCCombine(SDNode *N,
34363439 return SDValue ();
34373440}
34383441
3442+ static SDValue performAddCombine (SDNode *N, SelectionDAG &DAG) {
3443+ assert (N->getOpcode () == ISD::ADD);
3444+ EVT VT = N->getValueType (0 );
3445+ SDValue N0 = N->getOperand (0 ), N1 = N->getOperand (1 );
3446+
3447+ if (VT != MVT::v4i32)
3448+ return SDValue ();
3449+
3450+ auto IsShuffleWithMask = [](SDValue V, ArrayRef<int > ShuffleValue) {
3451+ if (V.getOpcode () != ISD::VECTOR_SHUFFLE)
3452+ return SDValue ();
3453+ if (cast<ShuffleVectorSDNode>(V)->getMask () != ShuffleValue)
3454+ return SDValue ();
3455+ return V;
3456+ };
3457+ auto ShuffleA = IsShuffleWithMask (N0, {0 , 2 , 4 , 6 });
3458+ auto ShuffleB = IsShuffleWithMask (N1, {1 , 3 , 5 , 7 });
3459+ // two SDValues must be muls
3460+ if (!ShuffleA || !ShuffleB)
3461+ return SDValue ();
3462+
3463+ if (ShuffleA.getOperand (0 ) != ShuffleB.getOperand (0 ) ||
3464+ ShuffleA.getOperand (1 ) != ShuffleB.getOperand (1 ))
3465+ return SDValue ();
3466+
3467+ auto IsMulExtend =
3468+ [](SDValue V, WebAssemblyISD::NodeType I) -> std::pair<SDValue, SDValue> {
3469+ if (V.getOpcode () != ISD::MUL)
3470+ return {};
3471+
3472+ auto V0 = V.getOperand (0 ), V1 = V.getOperand (1 );
3473+ if (V0.getOpcode () != I || V1.getOpcode () != I)
3474+ return {};
3475+ return {V0.getOperand (0 ), V1.getOperand (0 )};
3476+ };
3477+
3478+ auto [LowA, LowB] =
3479+ IsMulExtend (ShuffleA.getOperand (0 ), WebAssemblyISD::EXTEND_LOW_S);
3480+ auto [HighA, HighB] =
3481+ IsMulExtend (ShuffleA.getOperand (1 ), WebAssemblyISD::EXTEND_HIGH_S);
3482+
3483+ if (!LowA || !LowB || !HighA || !HighB || LowA != HighA || LowB != HighB)
3484+ return SDValue ();
3485+
3486+ return DAG.getNode (WebAssemblyISD::DOT, SDLoc (N), MVT::v4i32, LowA, LowB);
3487+ }
34393488static SDValue performMulCombine (SDNode *N, SelectionDAG &DAG) {
34403489 assert (N->getOpcode () == ISD::MUL);
34413490 EVT VT = N->getValueType (0 );
@@ -3558,5 +3607,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
35583607 }
35593608 case ISD::MUL:
35603609 return performMulCombine (N, DCI.DAG );
3610+ case ISD::ADD:
3611+ return performAddCombine (N, DCI.DAG );
35613612 }
35623613}
0 commit comments