@@ -422,24 +422,30 @@ bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
422
422
return true ;
423
423
424
424
EVT VT = EVT::getEVT (I->getType ());
425
+ if (VT.getSizeInBits () > 128 )
426
+ return true ;
427
+
425
428
auto Op1 = I->getOperand (1 );
426
429
427
430
if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
428
- if (InstructionOpcodeToISD (InputInst->getOpcode ()) != ISD::MUL)
429
- return true ;
430
-
431
- if (isa<Instruction>(InputInst->getOperand (0 )) &&
432
- isa<Instruction>(InputInst->getOperand (1 ))) {
433
- // dot only supports signed inputs but also support lowering unsigned.
434
- if (cast<Instruction>(InputInst->getOperand (0 ))->getOpcode () !=
435
- cast<Instruction>(InputInst->getOperand (1 ))->getOpcode ())
436
- return true ;
437
-
438
- EVT Op1VT = EVT::getEVT (Op1->getType ());
439
- if (Op1VT.getVectorElementType () == VT.getVectorElementType () &&
440
- ((VT.getVectorElementCount () * 2 == Op1VT.getVectorElementCount ()) ||
441
- (VT.getVectorElementCount () * 4 == Op1VT.getVectorElementCount ())))
442
- return false ;
431
+ unsigned Opcode = InstructionOpcodeToISD (InputInst->getOpcode ());
432
+ if (Opcode == ISD::MUL) {
433
+ if (isa<Instruction>(InputInst->getOperand (0 )) &&
434
+ isa<Instruction>(InputInst->getOperand (1 ))) {
435
+ // dot only supports signed inputs but also support lowering unsigned.
436
+ if (cast<Instruction>(InputInst->getOperand (0 ))->getOpcode () !=
437
+ cast<Instruction>(InputInst->getOperand (1 ))->getOpcode ())
438
+ return true ;
439
+
440
+ EVT Op1VT = EVT::getEVT (Op1->getType ());
441
+ if (Op1VT.getVectorElementType () == VT.getVectorElementType () &&
442
+ ((VT.getVectorElementCount () * 2 ==
443
+ Op1VT.getVectorElementCount ()) ||
444
+ (VT.getVectorElementCount () * 4 == Op1VT.getVectorElementCount ())))
445
+ return false ;
446
+ }
447
+ } else if (ISD::isExtOpcode (Opcode)) {
448
+ return false ;
443
449
}
444
450
}
445
451
return true ;
@@ -2117,77 +2123,93 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
2117
2123
2118
2124
assert (N->getValueType (0 ) == MVT::v4i32 && " can only support v4i32" );
2119
2125
SDLoc DL (N);
2120
- SDValue Mul = N->getOperand (2 );
2121
- assert (Mul->getOpcode () == ISD::MUL && " expected mul input" );
2122
-
2123
- SDValue ExtendLHS = Mul->getOperand (0 );
2124
- SDValue ExtendRHS = Mul->getOperand (1 );
2125
- assert ((ISD::isExtOpcode (ExtendLHS.getOpcode ()) &&
2126
- ISD::isExtOpcode (ExtendRHS.getOpcode ())) &&
2127
- " expected widening mul" );
2128
- assert (ExtendLHS.getOpcode () == ExtendRHS.getOpcode () &&
2129
- " expected mul to use the same extend for both operands" );
2130
-
2131
- SDValue ExtendInLHS = ExtendLHS->getOperand (0 );
2132
- SDValue ExtendInRHS = ExtendRHS->getOperand (0 );
2133
- bool IsSigned = ExtendLHS->getOpcode () == ISD::SIGN_EXTEND;
2134
-
2135
- if (ExtendInLHS->getValueType (0 ) == MVT::v8i16) {
2136
- if (IsSigned) {
2137
- // i32x4.dot_i16x8_s
2138
- SDValue Dot = DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32,
2139
- ExtendInLHS, ExtendInRHS);
2140
- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Dot);
2141
- }
2142
2126
2143
- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
2144
- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
2127
+ SDValue Input = N->getOperand (2 );
2128
+ if (Input->getOpcode () == ISD::MUL) {
2129
+ SDValue ExtendLHS = Input->getOperand (0 );
2130
+ SDValue ExtendRHS = Input->getOperand (1 );
2131
+ assert ((ISD::isExtOpcode (ExtendLHS.getOpcode ()) &&
2132
+ ISD::isExtOpcode (ExtendRHS.getOpcode ())) &&
2133
+ " expected widening mul or add" );
2134
+ assert (ExtendLHS.getOpcode () == ExtendRHS.getOpcode () &&
2135
+ " expected binop to use the same extend for both operands" );
2136
+
2137
+ SDValue ExtendInLHS = ExtendLHS->getOperand (0 );
2138
+ SDValue ExtendInRHS = ExtendRHS->getOperand (0 );
2139
+ bool IsSigned = ExtendLHS->getOpcode () == ISD::SIGN_EXTEND;
2140
+ unsigned LowOpc =
2141
+ IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
2142
+ unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
2143
+ : WebAssemblyISD::EXTEND_HIGH_U;
2144
+ SDValue LowLHS;
2145
+ SDValue LowRHS;
2146
+ SDValue HighLHS;
2147
+ SDValue HighRHS;
2148
+
2149
+ auto AssignInputs = [&](MVT VT) {
2150
+ LowLHS = DAG.getNode (LowOpc, DL, VT, ExtendInLHS);
2151
+ LowRHS = DAG.getNode (LowOpc, DL, VT, ExtendInRHS);
2152
+ HighLHS = DAG.getNode (HighOpc, DL, VT, ExtendInLHS);
2153
+ HighRHS = DAG.getNode (HighOpc, DL, VT, ExtendInRHS);
2154
+ };
2145
2155
2146
- // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
2147
- SDValue LowLHS = DAG.getNode (LowOpc, DL, MVT::v4i32, ExtendInLHS);
2148
- SDValue LowRHS = DAG.getNode (LowOpc, DL, MVT::v4i32, ExtendInRHS);
2149
- SDValue HighLHS = DAG.getNode (HighOpc, DL, MVT::v4i32, ExtendInLHS);
2150
- SDValue HighRHS = DAG.getNode (HighOpc, DL, MVT::v4i32, ExtendInRHS);
2156
+ if (ExtendInLHS->getValueType (0 ) == MVT::v8i16) {
2157
+ if (IsSigned) {
2158
+ // i32x4.dot_i16x8_s
2159
+ SDValue Dot = DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32,
2160
+ ExtendInLHS, ExtendInRHS);
2161
+ return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Dot);
2162
+ }
2151
2163
2152
- SDValue MulLow = DAG.getNode (ISD::MUL, DL, MVT::v4i32, LowLHS, LowRHS);
2153
- SDValue MulHigh = DAG.getNode (ISD::MUL, DL, MVT::v4i32, HighLHS, HighRHS);
2154
- SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, MulLow, MulHigh);
2155
- return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2164
+ // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
2165
+ MVT VT = MVT::v4i32;
2166
+ AssignInputs (VT);
2167
+ SDValue MulLow = DAG.getNode (ISD::MUL, DL, VT, LowLHS, LowRHS);
2168
+ SDValue MulHigh = DAG.getNode (ISD::MUL, DL, VT, HighLHS, HighRHS);
2169
+ SDValue Add = DAG.getNode (ISD::ADD, DL, VT, MulLow, MulHigh);
2170
+ return DAG.getNode (ISD::ADD, DL, VT, N->getOperand (1 ), Add);
2171
+ } else {
2172
+ assert (ExtendInLHS->getValueType (0 ) == MVT::v16i8 &&
2173
+ " expected v16i8 input types" );
2174
+ AssignInputs (MVT::v8i16);
2175
+ // Lower to a wider tree, using twice the operations compared to above.
2176
+ if (IsSigned) {
2177
+ // Use two dots
2178
+ SDValue DotLHS =
2179
+ DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
2180
+ SDValue DotRHS =
2181
+ DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
2182
+ SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
2183
+ return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2184
+ }
2185
+
2186
+ SDValue MulLow = DAG.getNode (ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
2187
+ SDValue MulHigh = DAG.getNode (ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
2188
+
2189
+ SDValue AddLow = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2190
+ MVT::v4i32, MulLow);
2191
+ SDValue AddHigh = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2192
+ MVT::v4i32, MulHigh);
2193
+ SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
2194
+ return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2195
+ }
2156
2196
} else {
2157
- assert (ExtendInLHS->getValueType (0 ) == MVT::v16i8 &&
2158
- " expected v16i8 input types" );
2159
- // Lower to a wider tree, using twice the operations compared to above.
2160
- if (IsSigned) {
2161
- // Use two dots
2162
- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_S;
2163
- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_S;
2164
- SDValue LowLHS = DAG.getNode (LowOpc, DL, MVT::v8i16, ExtendInLHS);
2165
- SDValue LowRHS = DAG.getNode (LowOpc, DL, MVT::v8i16, ExtendInRHS);
2166
- SDValue HighLHS = DAG.getNode (HighOpc, DL, MVT::v8i16, ExtendInLHS);
2167
- SDValue HighRHS = DAG.getNode (HighOpc, DL, MVT::v8i16, ExtendInRHS);
2168
- SDValue DotLHS =
2169
- DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
2170
- SDValue DotRHS =
2171
- DAG.getNode (WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
2172
- SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
2197
+ // Accumulate the input using extadd_pairwise.
2198
+ assert (ISD::isExtOpcode (Input.getOpcode ()) && " expected extend" );
2199
+ bool IsSigned = Input->getOpcode () == ISD::SIGN_EXTEND;
2200
+ unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
2201
+ : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
2202
+ SDValue ExtendIn = Input->getOperand (0 );
2203
+ if (ExtendIn->getValueType (0 ) == MVT::v8i16) {
2204
+ SDValue Add = DAG.getNode (PairwiseOpc, DL, MVT::v4i32, ExtendIn);
2173
2205
return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2174
2206
}
2175
2207
2176
- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
2177
- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
2178
- SDValue LowLHS = DAG.getNode (LowOpc, DL, MVT::v8i16, ExtendInLHS);
2179
- SDValue LowRHS = DAG.getNode (LowOpc, DL, MVT::v8i16, ExtendInRHS);
2180
- SDValue HighLHS = DAG.getNode (HighOpc, DL, MVT::v8i16, ExtendInLHS);
2181
- SDValue HighRHS = DAG.getNode (HighOpc, DL, MVT::v8i16, ExtendInRHS);
2182
-
2183
- SDValue MulLow = DAG.getNode (ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
2184
- SDValue MulHigh = DAG.getNode (ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
2185
-
2186
- SDValue AddLow =
2187
- DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, MVT::v4i32, MulLow);
2188
- SDValue AddHigh = DAG.getNode (WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
2189
- MVT::v4i32, MulHigh);
2190
- SDValue Add = DAG.getNode (ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
2208
+ assert (ExtendIn->getValueType (0 ) == MVT::v16i8 &&
2209
+ " expected v16i8 input types" );
2210
+ SDValue Add =
2211
+ DAG.getNode (PairwiseOpc, DL, MVT::v4i32,
2212
+ DAG.getNode (PairwiseOpc, DL, MVT::v8i16, ExtendIn));
2191
2213
return DAG.getNode (ISD::ADD, DL, MVT::v4i32, N->getOperand (1 ), Add);
2192
2214
}
2193
2215
}
0 commit comments