Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// Combine wide-vector muls, with extend inputs, to extmul_half.
setTargetDAGCombine(ISD::MUL);

// Combine an add of even-/odd-lane shuffles of extended muls into a dot.
setTargetDAGCombine(ISD::ADD);

// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);

Expand Down Expand Up @@ -3436,6 +3439,53 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}

/// Fold a pairwise add of products of sign-extended lanes into a single
/// WebAssemblyISD::DOT node.
///
/// The matched shape is
///   add (shuffle MulLow, MulHigh, <0,2,4,6>),
///       (shuffle MulLow, MulHigh, <1,3,5,7>)
/// with
///   MulLow  = mul (EXTEND_LOW_S  A), (EXTEND_LOW_S  B)
///   MulHigh = mul (EXTEND_HIGH_S A), (EXTEND_HIGH_S B)
/// and it is replaced by `DOT A, B`.
///
/// ADD and MUL are commutative, so the two shuffles may appear in either
/// operand position of the add, and the high multiply may list its extended
/// operands in the opposite order from the low multiply.
///
/// \param N   The ISD::ADD node being combined.
/// \param DAG The DAG used to create the replacement node.
/// \returns The new DOT node, or an empty SDValue when the pattern does not
///          match.
static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) {
  assert(N->getOpcode() == ISD::ADD);

  // Only a v4i32 result is handled.
  if (N->getValueType(0) != MVT::v4i32)
    return SDValue();

  // Returns V when it is a VECTOR_SHUFFLE whose mask equals Mask exactly,
  // otherwise an empty SDValue.
  auto MatchShuffle = [](SDValue V, ArrayRef<int> Mask) {
    if (V.getOpcode() != ISD::VECTOR_SHUFFLE ||
        cast<ShuffleVectorSDNode>(V)->getMask() != Mask)
      return SDValue();
    return V;
  };

  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
  SDValue EvenLanes = MatchShuffle(N0, {0, 2, 4, 6});
  SDValue OddLanes = MatchShuffle(N1, {1, 3, 5, 7});
  if (!EvenLanes || !OddLanes) {
    // ADD is commutative: accept the odd-lane shuffle as the first operand
    // and the even-lane shuffle as the second one as well.
    EvenLanes = MatchShuffle(N1, {0, 2, 4, 6});
    OddLanes = MatchShuffle(N0, {1, 3, 5, 7});
  }
  // Both add operands must be the expected shuffles.
  if (!EvenLanes || !OddLanes)
    return SDValue();

  // Both shuffles must read from the same pair of source vectors, in the same
  // order.
  SDValue MulLow = EvenLanes.getOperand(0);
  SDValue MulHigh = EvenLanes.getOperand(1);
  if (MulLow != OddLanes.getOperand(0) || MulHigh != OddLanes.getOperand(1))
    return SDValue();

  // When V is `mul (ExtOpc X), (ExtOpc Y)` returns {X, Y}; otherwise returns
  // a pair of empty SDValues.
  auto MatchMulOfExtends =
      [](SDValue V,
         WebAssemblyISD::NodeType ExtOpc) -> std::pair<SDValue, SDValue> {
    if (V.getOpcode() != ISD::MUL)
      return {};

    SDValue LHS = V.getOperand(0), RHS = V.getOperand(1);
    if (LHS.getOpcode() != ExtOpc || RHS.getOpcode() != ExtOpc)
      return {};
    return {LHS.getOperand(0), RHS.getOperand(0)};
  };

  auto [LowA, LowB] = MatchMulOfExtends(MulLow, WebAssemblyISD::EXTEND_LOW_S);
  auto [HighA, HighB] =
      MatchMulOfExtends(MulHigh, WebAssemblyISD::EXTEND_HIGH_S);
  if (!LowA || !LowB || !HighA || !HighB)
    return SDValue();

  // The low and high halves must come from the same two vectors. MUL is
  // commutative, so the high multiply may name them in swapped order.
  const bool SameOrder = LowA == HighA && LowB == HighB;
  const bool SwappedOrder = LowA == HighB && LowB == HighA;
  if (!SameOrder && !SwappedOrder)
    return SDValue();

  // Defensive: only build the v4i32 DOT from v8i16 inputs, the form that
  // selects to i32x4.dot_i16x8_s.
  if (LowA.getValueType() != MVT::v8i16 || LowB.getValueType() != MVT::v8i16)
    return SDValue();

  return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB);
}

static SDValue TryWideExtMulCombine(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
if (VT != MVT::v8i32 && VT != MVT::v16i32)
Expand Down Expand Up @@ -3597,5 +3647,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
}
case ISD::MUL:
return performMulCombine(N, DCI);
case ISD::ADD:
return performAddCombine(N, DCI.DAG);
}
}
21 changes: 21 additions & 0 deletions llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mattr=+simd128 | FileCheck %s

target triple = "wasm32-unknown-unknown"
; Checks that adding the even-lane and odd-lane shuffles of a widened
; (sign-extended <8 x i16> -> <8 x i32>) multiply is selected as a single
; i32x4.dot_i16x8_s instruction.
define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
; CHECK-LABEL: dot:
; CHECK: .functype dot (v128, v128) -> (v128)
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i32x4.dot_i16x8_s
; CHECK-NEXT: # fallthrough-return
%sext1 = sext <8 x i16> %a to <8 x i32>
%sext2 = sext <8 x i16> %b to <8 x i32>
%mul = mul nsw <8 x i32> %sext1, %sext2
%shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
%shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
%res = add <4 x i32> %shuffle1, %shuffle2
ret <4 x i32> %res
}

Loading