Skip to content

Commit cb9aac0

Browse files
committed
Added combine support for dot
1 parent 4d304c8 commit cb9aac0

File tree

2 files changed

+52
-12
lines changed

2 files changed

+52
-12
lines changed

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
192192
// Combine wide-vector muls, with extend inputs, to extmul_half.
193193
setTargetDAGCombine(ISD::MUL);
194194

195+
// Combine an add of even-/odd-lane shuffles of extended muls into a dot.
196+
setTargetDAGCombine(ISD::ADD);
197+
195198
// Combine vector mask reductions into alltrue/anytrue
196199
setTargetDAGCombine(ISD::SETCC);
197200

@@ -3436,6 +3439,52 @@ static SDValue performSETCCCombine(SDNode *N,
34363439
return SDValue();
34373440
}
34383441

3442+
/// Fold a pairwise-add of widened i16 products into a single dot node.
///
/// Matches, for v4i32 results only:
///
///   (add (vector_shuffle<0,2,4,6> MulLow, MulHigh),
///        (vector_shuffle<1,3,5,7> MulLow, MulHigh))
///
/// where
///
///   MulLow  = (mul (EXTEND_LOW_S  A), (EXTEND_LOW_S  B))
///   MulHigh = (mul (EXTEND_HIGH_S A), (EXTEND_HIGH_S B))
///
/// and rewrites it to (WebAssemblyISD::DOT A, B). Because both ADD and MUL are
/// commutative, the two shuffles may appear in either order under the add, and
/// the high multiply may list its sources in the opposite order from the low
/// multiply.
///
/// \param N   The ISD::ADD node being combined.
/// \param DAG The DAG used to create the replacement node.
/// \returns The new DOT node, or an empty SDValue when the pattern does not
///          match.
static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) {
  assert(N->getOpcode() == ISD::ADD);
  EVT VT = N->getValueType(0);

  if (VT != MVT::v4i32)
    return SDValue();

  const int EvenLanes[] = {0, 2, 4, 6};
  const int OddLanes[] = {1, 3, 5, 7};

  // True iff V is a vector shuffle whose mask is exactly Mask.
  auto HasShuffleMask = [](SDValue V, ArrayRef<int> Mask) {
    if (V.getOpcode() != ISD::VECTOR_SHUFFLE)
      return false;
    return cast<ShuffleVectorSDNode>(V)->getMask() == Mask;
  };

  // ADD is commutative: accept the even-lane shuffle on either side.
  SDValue EvenShuffle = N->getOperand(0), OddShuffle = N->getOperand(1);
  if (!HasShuffleMask(EvenShuffle, EvenLanes)) {
    EvenShuffle = N->getOperand(1);
    OddShuffle = N->getOperand(0);
  }

  // Both add operands must be shuffles, one selecting the even lanes and the
  // other selecting the odd lanes.
  if (!HasShuffleMask(EvenShuffle, EvenLanes) ||
      !HasShuffleMask(OddShuffle, OddLanes))
    return SDValue();

  // The two shuffles must read from the same pair of source vectors, in the
  // same order; otherwise the lanes being added do not line up.
  SDValue MulLow = EvenShuffle.getOperand(0);
  SDValue MulHigh = EvenShuffle.getOperand(1);
  if (OddShuffle.getOperand(0) != MulLow ||
      OddShuffle.getOperand(1) != MulHigh)
    return SDValue();

  // If V is (mul (ExtOpc X), (ExtOpc Y)) return {X, Y}; otherwise return a
  // pair of empty SDValues.
  auto MatchMulOfExtends =
      [](SDValue V,
         WebAssemblyISD::NodeType ExtOpc) -> std::pair<SDValue, SDValue> {
    if (V.getOpcode() != ISD::MUL)
      return {};

    SDValue LHS = V.getOperand(0), RHS = V.getOperand(1);
    if (LHS.getOpcode() != ExtOpc || RHS.getOpcode() != ExtOpc)
      return {};
    return {LHS.getOperand(0), RHS.getOperand(0)};
  };

  auto [LowA, LowB] = MatchMulOfExtends(MulLow, WebAssemblyISD::EXTEND_LOW_S);
  auto [HighA, HighB] =
      MatchMulOfExtends(MulHigh, WebAssemblyISD::EXTEND_HIGH_S);

  if (!LowA || !LowB || !HighA || !HighB)
    return SDValue();

  // The low and high halves must be extended from the same two vectors. MUL is
  // commutative, so the high multiply may name them in either order.
  const bool SameOrder = LowA == HighA && LowB == HighB;
  const bool SwappedOrder = LowA == HighB && LowB == HighA;
  if (!SameOrder && !SwappedOrder)
    return SDValue();

  // The dot being formed multiplies i16 lanes; refuse to build the node from
  // any other source type.
  if (LowA.getValueType() != MVT::v8i16 || LowB.getValueType() != MVT::v8i16)
    return SDValue();

  return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), VT, LowA, LowB);
}
34393488
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
34403489
assert(N->getOpcode() == ISD::MUL);
34413490
EVT VT = N->getValueType(0);
@@ -3558,5 +3607,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
35583607
}
35593608
case ISD::MUL:
35603609
return performMulCombine(N, DCI.DAG);
3610+
case ISD::ADD:
3611+
return performAddCombine(N, DCI.DAG);
35613612
}
35623613
}

llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,10 @@ target triple = "wasm32-unknown-unknown"
55
define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
66
; CHECK-LABEL: dot:
77
; CHECK: .functype dot (v128, v128) -> (v128)
8-
; CHECK-NEXT: .local v128
98
; CHECK-NEXT: # %bb.0:
109
; CHECK-NEXT: local.get 0
1110
; CHECK-NEXT: local.get 1
12-
; CHECK-NEXT: i32x4.extmul_low_i16x8_s
13-
; CHECK-NEXT: local.tee 2
14-
; CHECK-NEXT: local.get 0
15-
; CHECK-NEXT: local.get 1
16-
; CHECK-NEXT: i32x4.extmul_high_i16x8_s
17-
; CHECK-NEXT: local.tee 1
18-
; CHECK-NEXT: i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
19-
; CHECK-NEXT: local.get 2
20-
; CHECK-NEXT: local.get 1
21-
; CHECK-NEXT: i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
22-
; CHECK-NEXT: i32x4.add
11+
; CHECK-NEXT: i32x4.dot_i16x8_s
2312
; CHECK-NEXT: # fallthrough-return
2413
%sext1 = sext <8 x i16> %a to <8 x i32>
2514
%sext2 = sext <8 x i16> %b to <8 x i32>

0 commit comments

Comments
 (0)